diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 6cbc5b30cb..9d804dc53a 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -101,7 +101,8 @@ def __init__( Default: None (no config). s3_client (Optional[boto3.client]): s3 client to use. Default: None. sagemaker_session (Optional[sagemaker.session.Session]): A SageMaker Session object, - used for SageMaker interactions. Default: Session in region associated with boto3 session. + used for SageMaker interactions. Default: Session in region associated with boto3 + session. """ self._region = region @@ -358,7 +359,9 @@ def _retrieval_function( hub_content_type=data_type ) - model_specs = JumpStartModelSpecs(DescribeHubContentsResponse(hub_model_description), is_hub_content=True) + model_specs = JumpStartModelSpecs( + DescribeHubContentsResponse(hub_model_description), is_hub_content=True + ) utils.emit_logs_based_on_model_specs( model_specs, @@ -372,7 +375,9 @@ def _retrieval_function( hub_name, _, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info) response: Dict[str, Any] = self._sagemaker_session.describe_hub(hub_name=hub_name) hub_description = DescribeHubResponse(response) - return JumpStartCachedContentValue(formatted_content=DescribeHubResponse(hub_description)) + return JumpStartCachedContentValue( + formatted_content=DescribeHubResponse(hub_description) + ) raise ValueError( f"Bad value for key '{key}': must be in ", f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubType.HUB, HubContentType.MODEL]}" diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/__init__.py b/src/sagemaker/jumpstart/curated_hub/accessors/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py b/src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py new file mode 100644 index 0000000000..0393b4234a --- /dev/null +++ 
b/src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py @@ -0,0 +1,93 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains important utilities related to HubContent data files.""" +from __future__ import absolute_import +from typing import Any, Dict, List + +from botocore.client import BaseClient + +from sagemaker.jumpstart.curated_hub.types import ( + FileInfo, + HubContentDependencyType, + S3ObjectLocation, +) +from sagemaker.jumpstart.curated_hub.accessors.public_model_data import PublicModelDataAccessor +from sagemaker.jumpstart.types import JumpStartModelSpecs + + +def generate_file_infos_from_s3_location( + location: S3ObjectLocation, s3_client: BaseClient +) -> List[FileInfo]: + """Lists objects from an S3 bucket and formats into FileInfo. + + Returns a list of ``FileInfo`` objects from the specified bucket location. 
+ """ + parameters = {"Bucket": location.bucket, "Prefix": location.key} + response = s3_client.list_objects_v2(**parameters) + contents = response.get("Contents") + + if not contents: + return [] + + files = [] + for s3_obj in contents: + key = s3_obj.get("Key") + size = s3_obj.get("Size") + last_modified = s3_obj.get("LastModified") + files.append(FileInfo(location.bucket, key, size, last_modified)) + return files + + +def generate_file_infos_from_model_specs( + model_specs: JumpStartModelSpecs, + studio_specs: Dict[str, Any], + region: str, + s3_client: BaseClient, +) -> List[FileInfo]: + """Collects data locations from JumpStart public model specs and converts into `FileInfo`. + + Returns a list of `FileInfo` objects from dependencies found in the public + model specs. + """ + public_model_data_accessor = PublicModelDataAccessor( + region=region, model_specs=model_specs, studio_specs=studio_specs + ) + files = [] + for dependency in HubContentDependencyType: + location: S3ObjectLocation = public_model_data_accessor.get_s3_reference(dependency) + location_type = "prefix" if location.key.endswith("/") else "object" + + if location_type == "prefix": + parameters = {"Bucket": location.bucket, "Prefix": location.key} + response = s3_client.list_objects_v2(**parameters) + contents = response.get("Contents") + for s3_obj in contents: + key = s3_obj.get("Key") + size = s3_obj.get("Size") + last_modified = s3_obj.get("LastModified") + files.append( + FileInfo( + location.bucket, + key, + size, + last_modified, + dependency, + ) + ) + elif location_type == "object": + parameters = {"Bucket": location.bucket, "Key": location.key} + response = s3_client.head_object(**parameters) + size = response.get("ContentLength") + last_updated = response.get("LastModified") + files.append(FileInfo(location.bucket, location.key, size, last_updated, dependency)) + return files diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py 
b/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py new file mode 100644 index 0000000000..52174f9eeb --- /dev/null +++ b/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py @@ -0,0 +1,143 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module provides a class that perfrms functionalities similar to ``S3:Copy``.""" +from __future__ import absolute_import +from typing import Optional + +import boto3 +import botocore +import tqdm + +from sagemaker.jumpstart.constants import JUMPSTART_LOGGER +from sagemaker.jumpstart.curated_hub.types import FileInfo +from sagemaker.jumpstart.curated_hub.sync.request import HubSyncRequest + +s3transfer = boto3.s3.transfer + + +# pylint: disable=R1705,R1710 +def human_readable_size(value: int) -> str: + """Convert a size in bytes into a human readable format. + + For example:: + + >>> human_readable_size(1) + '1 Byte' + >>> human_readable_size(10) + '10 Bytes' + >>> human_readable_size(1024) + '1.0 KiB' + >>> human_readable_size(1024 * 1024) + '1.0 MiB' + + :param value: The size in bytes. + :return: The size in a human readable format based on base-2 units. 
+ + """ + base = 1024 + bytes_int = float(value) + + if bytes_int == 1: + return "1 Byte" + elif bytes_int < base: + return "%d Bytes" % bytes_int + + for i, suffix in enumerate(("KiB", "MiB", "GiB", "TiB", "PiB", "EiB")): + unit = base ** (i + 2) + if round((bytes_int / unit) * base) < base: + return "%.1f %s" % ((base * bytes_int / unit), suffix) + + +class MultiPartCopyHandler(object): + """Multi Part Copy Handler class.""" + + WORKERS = 20 + # Config values from in S3:Copy + MULTIPART_CONFIG = 8 * (1024**2) + + def __init__( + self, + region: str, + sync_request: HubSyncRequest, + label: Optional[str] = None, + thread_num: Optional[int] = 0, + ): + """Multi-part S3:Copy Handler initializer. + + Args: + region (str): Region for the S3 Client + sync_request (HubSyncRequest): sync request object containing + information required to perform the copy + """ + self.label = label + self.region = region + self.files = sync_request.files + self.dest_location = sync_request.destination + self.thread_num = thread_num + + config = botocore.config.Config(max_pool_connections=self.WORKERS) + self.s3_client = boto3.client("s3", region_name=self.region, config=config) + transfer_config = s3transfer.TransferConfig( + multipart_threshold=self.MULTIPART_CONFIG, + multipart_chunksize=self.MULTIPART_CONFIG, + max_bandwidth=True, + use_threads=True, + max_concurrency=self.WORKERS, + ) + self.transfer_manager = s3transfer.create_transfer_manager( + client=self.s3_client, config=transfer_config + ) + + def _copy_file(self, file: FileInfo, progress_cb): + """Performs the actual MultiPart S3:Copy of the object.""" + copy_source = {"Bucket": file.location.bucket, "Key": file.location.key} + result = self.transfer_manager.copy( + bucket=self.dest_location.bucket, + key=f"{self.dest_location.key}/{file.location.key}", + copy_source=copy_source, + subscribers=[ + s3transfer.ProgressCallbackInvoker(progress_cb), + ], + ) + # Attempt to access result to throw error if exists. 
Silently calls if successful. + result.result() + + def execute(self): + """Executes the MultiPart S3:Copy on the class. + + Sets up progress bar and kicks off each copy request. + """ + total_size = sum([file.size for file in self.files]) + JUMPSTART_LOGGER.warning( + "Copying %s files (%s) into %s/%s", + len(self.files), + human_readable_size(total_size), + self.dest_location.bucket, + self.dest_location.key, + ) + + progress = tqdm.tqdm( + desc=self.label, + total=total_size, + unit="B", + unit_scale=1, + position=self.thread_num, + bar_format="{desc:<10}{percentage:3.0f}%|{bar:10}{r_bar}", + ) + + for file in self.files: + self._copy_file(file, progress.update) + + # Call `shutdown` to wait for copy results + self.transfer_manager.shutdown() + progress.close() diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py new file mode 100644 index 0000000000..89e3a2f108 --- /dev/null +++ b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py @@ -0,0 +1,116 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""This module accessors for the SageMaker JumpStart Public Hub.""" +from __future__ import absolute_import +from typing import Dict, Any +from sagemaker import model_uris, script_uris +from sagemaker.jumpstart.curated_hub.types import ( + HubContentDependencyType, + S3ObjectLocation, +) +from sagemaker.jumpstart.curated_hub.utils import create_s3_object_reference_from_uri +from sagemaker.jumpstart.enums import JumpStartScriptScope +from sagemaker.jumpstart.types import JumpStartModelSpecs +from sagemaker.jumpstart.utils import get_jumpstart_content_bucket + + +class PublicModelDataAccessor: + """Accessor class for JumpStart model data s3 locations.""" + + def __init__( + self, + region: str, + model_specs: JumpStartModelSpecs, + studio_specs: Dict[str, Dict[str, Any]], + ): + self._region = region + self._bucket = get_jumpstart_content_bucket(region) + self.model_specs = model_specs + self.studio_specs = studio_specs # Necessary for SDK - Studio metadata drift + + def get_s3_reference(self, dependency_type: HubContentDependencyType): + """Retrieves S3 reference given a HubContentDependencyType.""" + return getattr(self, dependency_type.value) + + @property + def inference_artifact_s3_reference(self): + """Retrieves s3 reference for model inference artifact""" + return create_s3_object_reference_from_uri( + self._jumpstart_artifact_s3_uri(JumpStartScriptScope.INFERENCE) + ) + + @property + def training_artifact_s3_reference(self): + """Retrieves s3 reference for model training artifact""" + return create_s3_object_reference_from_uri( + self._jumpstart_artifact_s3_uri(JumpStartScriptScope.TRAINING) + ) + + @property + def inference_script_s3_reference(self): + """Retrieves s3 reference for model inference script""" + return create_s3_object_reference_from_uri( + self._jumpstart_script_s3_uri(JumpStartScriptScope.INFERENCE) + ) + + @property + def training_script_s3_reference(self): + """Retrieves s3 reference for model training script""" + return 
create_s3_object_reference_from_uri( + self._jumpstart_script_s3_uri(JumpStartScriptScope.TRAINING) + ) + + @property + def default_training_dataset_s3_reference(self): + """Retrieves s3 reference for s3 directory containing model training datasets""" + return S3ObjectLocation(self._get_bucket_name(), self.__get_training_dataset_prefix()) + + @property + def demo_notebook_s3_reference(self): + """Retrieves s3 reference for model demo jupyter notebook""" + framework = self.model_specs.get_framework() + key = f"{framework}-notebooks/{self.model_specs.model_id}-inference.ipynb" + return S3ObjectLocation(self._get_bucket_name(), key) + + @property + def markdown_s3_reference(self): + """Retrieves s3 reference for model markdown""" + framework = self.model_specs.get_framework() + key = f"{framework}-metadata/{self.model_specs.model_id}.md" + return S3ObjectLocation(self._get_bucket_name(), key) + + def _get_bucket_name(self) -> str: + """Retrieves s3 bucket""" + return self._bucket + + def __get_training_dataset_prefix(self) -> str: + """Retrieves training dataset location""" + return self.studio_specs["defaultDataKey"] + + def _jumpstart_script_s3_uri(self, model_scope: str) -> str: + """Retrieves JumpStart script s3 location""" + return script_uris.retrieve( + region=self._region, + model_id=self.model_specs.model_id, + model_version=self.model_specs.version, + script_scope=model_scope, + ) + + def _jumpstart_artifact_s3_uri(self, model_scope: str) -> str: + """Retrieves JumpStart artifact s3 location""" + return model_uris.retrieve( + region=self._region, + model_id=self.model_specs.model_id, + model_version=self.model_specs.version, + model_scope=model_scope, + ) diff --git a/src/sagemaker/jumpstart/curated_hub/constants.py b/src/sagemaker/jumpstart/curated_hub/constants.py new file mode 100644 index 0000000000..5d35aa80c6 --- /dev/null +++ b/src/sagemaker/jumpstart/curated_hub/constants.py @@ -0,0 +1,19 @@ +# Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module stores constants related to SageMaker JumpStart CuratedHub.""" +from __future__ import absolute_import + +JUMPSTART_HUB_MODEL_ID_TAG_PREFIX = "@jumpstart-model-id" +JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX = "@jumpstart-model-version" +FRAMEWORK_TAG_PREFIX = "@framework" +TASK_TAG_PREFIX = "@mltask" diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 59a11df577..a35948f138 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -12,17 +12,47 @@ # language governing permissions and limitations under the License. 
"""This module provides the JumpStart Curated Hub class.""" from __future__ import absolute_import +from concurrent import futures +from datetime import datetime +import json +import traceback +from typing import Optional, Dict, List, Any +import boto3 +from botocore import exceptions +from botocore.client import BaseClient +from packaging.version import Version -from typing import Any, Dict, Optional +from sagemaker.jumpstart import utils +from sagemaker.jumpstart.curated_hub.accessors import file_generator +from sagemaker.jumpstart.curated_hub.accessors.multipartcopy import MultiPartCopyHandler +from sagemaker.jumpstart.curated_hub.constants import ( + JUMPSTART_HUB_MODEL_ID_TAG_PREFIX, + JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX, + TASK_TAG_PREFIX, + FRAMEWORK_TAG_PREFIX, +) +from sagemaker.jumpstart.curated_hub.sync.comparator import SizeAndLastUpdatedComparator +from sagemaker.jumpstart.curated_hub.sync.request import HubSyncRequestFactory +from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.session import Session -from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER from sagemaker.jumpstart.types import ( DescribeHubResponse, DescribeHubContentsResponse, HubContentType, + JumpStartModelSpecs, +) +from sagemaker.jumpstart.curated_hub.utils import ( + create_hub_bucket_if_it_does_not_exist, + generate_default_hub_bucket_name, + create_s3_object_reference_from_uri, +) +from sagemaker.jumpstart.curated_hub.types import ( + HubContentDocument_v2, + JumpStartModelInfo, + S3ObjectLocation, ) -from sagemaker.jumpstart.curated_hub.utils import create_hub_bucket_if_it_does_not_exist class CuratedHub: @@ -31,8 +61,9 @@ class CuratedHub: def __init__( self, hub_name: str, + bucket_name: Optional[str] = None, sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - ): + ) -> None: """Instantiates a SageMaker 
``CuratedHub``. Args: @@ -43,25 +74,63 @@ def __init__( self.hub_name = hub_name self.region = sagemaker_session.boto_region_name self._sagemaker_session = sagemaker_session + self._default_thread_pool_size = 20 + self._s3_client = self._get_s3_client() + self.hub_storage_location = self._generate_hub_storage_location(bucket_name) + + def _get_s3_client(self) -> BaseClient: + """Returns an S3 client used for creating a HubContentDocument.""" + return boto3.client("s3", region_name=self.region) + + def _fetch_hub_bucket_name(self) -> str: + """Retrieves hub bucket name from Hub config if exists""" + try: + hub_response = self._sagemaker_session.describe_hub(hub_name=self.hub_name) + hub_output_location = hub_response["S3StorageConfig"].get("S3OutputPath") + if hub_output_location: + location = create_s3_object_reference_from_uri(hub_output_location) + return location.bucket + default_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session) + JUMPSTART_LOGGER.warning( + "There is not a Hub bucket associated with %s. Using %s", + self.hub_name, + default_bucket_name, + ) + return default_bucket_name + except exceptions.ClientError: + hub_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session) + JUMPSTART_LOGGER.warning( + "There is not a Hub bucket associated with %s. 
Using %s", + self.hub_name, + hub_bucket_name, + ) + return hub_bucket_name + + def _generate_hub_storage_location(self, bucket_name: Optional[str] = None) -> None: + """Generates an ``S3ObjectLocation`` given a Hub name.""" + hub_bucket_name = bucket_name or self._fetch_hub_bucket_name() + curr_timestamp = datetime.now().timestamp() + return S3ObjectLocation(bucket=hub_bucket_name, key=f"{self.hub_name}-{curr_timestamp}") def create( self, description: str, display_name: Optional[str] = None, search_keywords: Optional[str] = None, - bucket_name: Optional[str] = None, tags: Optional[str] = None, ) -> Dict[str, str]: """Creates a hub with the given description""" - bucket_name = create_hub_bucket_if_it_does_not_exist(bucket_name, self._sagemaker_session) + create_hub_bucket_if_it_does_not_exist( + self.hub_storage_location.bucket, self._sagemaker_session + ) return self._sagemaker_session.create_hub( hub_name=self.hub_name, hub_description=description, hub_display_name=display_name, hub_search_keywords=search_keywords, - hub_bucket_name=bucket_name, + s3_storage_config={"S3OutputPath": self.hub_storage_location.get_uri()}, tags=tags, ) @@ -113,3 +182,236 @@ def delete_model(self, model_name: str, model_version: str = "*") -> None: def delete(self) -> None: """Deletes this Curated Hub""" return self._sagemaker_session.delete_hub(self.hub_name) + + def _is_invalid_model_list_input(self, model_list: List[Dict[str, str]]) -> bool: + """Determines if input args to ``sync`` is correct. + + `model_list` objects must have `model_id` (str) and optional `version` (str). + """ + for obj in model_list: + if not isinstance(obj.get("model_id"), str): + return True + if "version" in obj and not isinstance(obj["version"], str): + return True + return False + + def _populate_latest_model_version(self, model: Dict[str, str]) -> Dict[str, str]: + """Populates the lastest version of a model from specs no matter what is passed. 
+ + Returns model ({ model_id: str, version: str }) + """ + model_specs = utils.verify_model_region_and_return_specs( + model["model_id"], "*", JumpStartScriptScope.INFERENCE, self.region + ) + return {"model_id": model["model_id"], "version": model_specs.version} + + def _get_jumpstart_models_in_hub(self) -> List[Dict[str, Any]]: + """Returns list of `HubContent` that have been created from a JumpStart model.""" + hub_models = self.list_models() + + js_models_in_hub = [] + for hub_model in hub_models["HubContentSummaries"]: + # TODO: extract both in one pass + jumpstart_model_id = next( + ( + tag + for tag in hub_model["search_keywords"] + if tag.startswith(JUMPSTART_HUB_MODEL_ID_TAG_PREFIX) + ), + None, + ) + jumpstart_model_version = next( + ( + tag + for tag in hub_model["search_keywords"] + if tag.startswith(JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX) + ), + None, + ) + + if jumpstart_model_id and jumpstart_model_version: + js_models_in_hub.append(hub_model) + + return js_models_in_hub + + def _determine_models_to_sync( + self, model_list: List[JumpStartModelInfo], models_in_hub: Dict[str, Any] + ) -> List[JumpStartModelInfo]: + """Determines which models from `sync` params to sync into the CuratedHub. + + Algorithm: + + First, look for a match of model name in Hub. If no model is found, sync that model. + + Next, compare versions of model to sync and what's in the Hub. If version already + in Hub, don't sync. If newer version in Hub, don't sync. If older version in Hub, + sync that model. + """ + models_to_sync = [] + for model in model_list: + matched_model = models_in_hub.get(model.model_id) + + # Model does not exist in Hub, sync + if not matched_model: + models_to_sync.append(model) + + if matched_model: + model_version = Version(model.version) + hub_model_version = Version(matched_model["version"]) + + # 1. Model version exists in Hub, pass + if hub_model_version == model_version: + pass + + # 2. 
Invalid model version exists in Hub, pass + # This will only happen if something goes wrong in our metadata + if hub_model_version > model_version: + pass + + # 3. Old model version exists in Hub, update + if hub_model_version < model_version: + # Check minSDKVersion against current SDK version, emit log + models_to_sync.append(model) + + return models_to_sync + + def sync(self, model_list: List[Dict[str, str]]): + """Syncs a list of JumpStart model ids and versions with a CuratedHub + + Args: + model_list (List[Dict[str, str]]): List of `{ model_id: str, version: Optional[str] }` + objects that should be synced into the Hub. + """ + if self._is_invalid_model_list_input(model_list): + raise ValueError( + "Model list should be a list of objects with values 'model_id',", + "and optional 'version'.", + ) + + # Retrieve latest version of unspecified JumpStart model versions + model_version_list = [] + for model in model_list: + version = model.get("version", "*") + if version == "*": + model = self._populate_latest_model_version(model) + JUMPSTART_LOGGER.warning( + "No version specified for model %s. Using version %s", + model["model_id"], + model["version"], + ) + model_version_list.append(JumpStartModelInfo(model["model_id"], model["version"])) + + js_models_in_hub = self._get_jumpstart_models_in_hub() + mapped_models_in_hub = {model["name"]: model for model in js_models_in_hub} + + models_to_sync = self._determine_models_to_sync(model_version_list, mapped_models_in_hub) + JUMPSTART_LOGGER.warning( + "Syncing the following models into Hub %s: %s", self.hub_name, models_to_sync + ) + + # Delete old models? 
+ + # CopyContentWorkflow + `SageMaker:ImportHubContent` for each model-to-sync in parallel + tasks: List[futures.Future] = [] + with futures.ThreadPoolExecutor( + max_workers=self._default_thread_pool_size, + thread_name_prefix="import-models-to-curated-hub", + ) as import_executor: + for thread_num, model in enumerate(models_to_sync): + task = import_executor.submit(self._sync_public_model_to_hub, model, thread_num) + tasks.append(task) + + # Handle failed imports + results = futures.wait(tasks) + failed_imports: List[Dict[str, Any]] = [] + for result in results.done: + exception = result.exception() + if exception: + failed_imports.append( + { + "Exception": exception, + "Traceback": "".join( + traceback.TracebackException.from_exception(exception).format() + ), + } + ) + if failed_imports: + raise RuntimeError( + f"Failures when importing models to curated hub in parallel: {failed_imports}" + ) + + def _sync_public_model_to_hub(self, model: JumpStartModelInfo, thread_num: int): + """Syncs a public JumpStart model version to the Hub. 
Runs in parallel.""" + model_specs = utils.verify_model_region_and_return_specs( + model_id=model.model_id, + version=model.version, + region=self.region, + scope=JumpStartScriptScope.INFERENCE, + sagemaker_session=self._sagemaker_session, + ) + studio_specs = self._fetch_studio_specs(model_specs=model_specs) + + dest_location = S3ObjectLocation( + bucket=self.hub_storage_location.bucket, + key=f"{self.hub_storage_location.key}/curated_models/{model.model_id}/{model.version}", + ) + src_files = file_generator.generate_file_infos_from_model_specs( + model_specs, studio_specs, self.region, self._s3_client + ) + dest_files = file_generator.generate_file_infos_from_s3_location( + dest_location, self._s3_client + ) + + comparator = SizeAndLastUpdatedComparator() + sync_request = HubSyncRequestFactory( + src_files, dest_files, dest_location, comparator + ).create() + + if len(sync_request.files) > 0: + MultiPartCopyHandler( + thread_num=thread_num, + sync_request=sync_request, + region=self.region, + label=dest_location.key, + ).execute() + else: + JUMPSTART_LOGGER.warning("Nothing to copy for %s v%s", model.model_id, model.version) + + # TODO: Tag model if specs say it is deprecated or training/inference + # vulnerable. Update tag of HubContent ARN without version. + # Versioned ARNs are not onboarded to Tagris. 
+ tags = [] + + search_keywords = [ + f"{JUMPSTART_HUB_MODEL_ID_TAG_PREFIX}:{model.model_id}", + f"{JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX}:{model.version}", + f"{FRAMEWORK_TAG_PREFIX}:{model_specs.get_framework()}", + f"{TASK_TAG_PREFIX}:TODO: pull from specs", + ] + + hub_content_document = str(HubContentDocument_v2(spec=model_specs)) + + self._sagemaker_session.import_hub_content( + document_schema_version=HubContentDocument_v2.SCHEMA_VERSION, + hub_content_name=model.model_id, + hub_content_version=model.version, + hub_name=self.hub_name, + hub_content_document=hub_content_document, + hub_content_type=HubContentType.MODEL, + hub_content_display_name="", + hub_content_description="", + hub_content_markdown="", + hub_content_search_keywords=search_keywords, + tags=tags, + ) + + def _fetch_studio_specs(self, model_specs: JumpStartModelSpecs) -> Dict[str, Any]: + """Fetches StudioSpecs given a model's SDK Specs.""" + model_id = model_specs.model_id + model_version = model_specs.version + + key = utils.generate_studio_spec_file_prefix(model_id, model_version) + response = self._s3_client.get_object( + Bucket=utils.get_jumpstart_content_bucket(self.region), Key=key + ) + return json.loads(response["Body"].read().decode("utf-8")) diff --git a/src/sagemaker/jumpstart/curated_hub/sync/comparator.py b/src/sagemaker/jumpstart/curated_hub/sync/comparator.py new file mode 100644 index 0000000000..ed7c1b9269 --- /dev/null +++ b/src/sagemaker/jumpstart/curated_hub/sync/comparator.py @@ -0,0 +1,78 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. +"""This module provides comparators for syncing s3 files.""" +from __future__ import absolute_import +from datetime import timedelta +from sagemaker.jumpstart.constants import JUMPSTART_LOGGER + +from sagemaker.jumpstart.curated_hub.types import FileInfo + + +class BaseComparator: + """BaseComparator object to be extended.""" + + def determine_should_sync(self, src_file: FileInfo, dest_file: FileInfo) -> bool: + """Custom comparator to determine if src file and dest file are in sync.""" + raise NotImplementedError + + +class SizeAndLastUpdatedComparator(BaseComparator): + """Comparator that uses file size and last modified time. + + Uses file size (bytes) and last_modified_time (timestamp) to determine sync. + """ + + def determine_should_sync(self, src_file: FileInfo, dest_file: FileInfo) -> bool: + """Determines if src file should be moved to dest folder.""" + same_size = self.compare_size(src_file, dest_file) + is_newer_dest_file = self.compare_file_updates(src_file, dest_file) + should_sync = (not same_size) or (not is_newer_dest_file) + if should_sync: + JUMPSTART_LOGGER.warning( + "syncing: %s -> %s, size: %s -> %s, modified time: %s -> %s", + src_file.location.key, + src_file.location.key, + src_file.size, + dest_file.size, + src_file.last_updated, + dest_file.last_updated, + ) + return should_sync + + def compare_size(self, src_file: FileInfo, dest_file: FileInfo): + """Compares sizes of src and dest files. + + :returns: True if the sizes are the same. + False otherwise. + """ + return src_file.size == dest_file.size + + def compare_file_updates(self, src_file: FileInfo, dest_file: FileInfo): + """Compares time delta between src and dest files. + + :returns: True if the file does not need updating based on time of + last modification and type of operation. 
+ False if the file does need updating based on the time of + last modification and type of operation. + """ + src_time = src_file.last_updated + dest_time = dest_file.last_updated + delta = dest_time - src_time + # pylint: disable=R1703,R1705 + if timedelta.total_seconds(delta) >= 0: + return True + else: + # Destination is older than source, so + # we have a more recently updated file + # at the source location. + return False diff --git a/src/sagemaker/jumpstart/curated_hub/sync/request.py b/src/sagemaker/jumpstart/curated_hub/sync/request.py new file mode 100644 index 0000000000..0e620432ce --- /dev/null +++ b/src/sagemaker/jumpstart/curated_hub/sync/request.py @@ -0,0 +1,148 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module provides a class that perfrms functionalities similar to ``S3:Copy``.""" +from __future__ import absolute_import +from dataclasses import dataclass +from typing import Generator, List + +from botocore.compat import six + +from sagemaker.jumpstart.curated_hub.sync.comparator import BaseComparator +from sagemaker.jumpstart.curated_hub.types import FileInfo, S3ObjectLocation + +advance_iterator = six.advance_iterator + + +@dataclass +class HubSyncRequest: + """HubSyncRequest class""" + + files: List[FileInfo] + destination: S3ObjectLocation + + def __init__( + self, files_to_copy: Generator[FileInfo, FileInfo, FileInfo], destination: S3ObjectLocation + ): + """Contains information required to sync data into a Hub. 
class HubSyncRequestFactory:
    """Generates a ``HubSyncRequest`` which is required to sync data into a Hub.

    Creates a ``HubSyncRequest`` class containing:
    :var: files (List[FileInfo]): Files that should be synced.
    :var: destination (S3ObjectLocation): Location to which to sync the files.
    """

    def __init__(
        self,
        src_files: List[FileInfo],
        dest_files: List[FileInfo],
        destination: S3ObjectLocation,
        comparator: BaseComparator,
    ):
        """Instantiates a ``HubSyncRequestFactory`` class.

        Args:
            src_files (List[FileInfo]): List of files to sync to destination bucket
            dest_files (List[FileInfo]): List of files already in destination bucket
            destination (S3ObjectLocation): S3 destination for copied data
            comparator (BaseComparator): Decides whether a same-named src/dest
                file pair requires a sync.
        """
        self.comparator = comparator
        self.destination = destination
        # The merge in _determine_files_to_copy relies on both lists being
        # sorted by object key.
        self.src_files: List[FileInfo] = sorted(src_files, key=lambda x: x.location.key)
        formatted_dest_files = [self._format_dest_file(file) for file in dest_files]
        self.dest_files: List[FileInfo] = sorted(
            formatted_dest_files, key=lambda x: x.location.key
        )

    def _format_dest_file(self, file: FileInfo) -> FileInfo:
        """Strips HubContent data prefix from dest file name."""
        file.location.key = file.location.key.replace(f"{self.destination.key}/", "")
        return file

    def create(self) -> HubSyncRequest:
        """Creates a ``HubSyncRequest`` with the `files` to copy and the `destination`.

        Based on the `s3:sync` algorithm.
        """
        return HubSyncRequest(self._determine_files_to_copy(), self.destination)

    def _determine_files_to_copy(self) -> Generator[FileInfo, None, None]:
        """Yields the ``FileInfo`` objects from src that need to be copied.

        Algorithm (merge of two sorted lists):
        Walk the sorted src files while advancing a pointer into the sorted
        dest files. For each src file, first skip over any dest entries whose
        keys sort strictly before it -- those are unexpected extras in dest
        and are ignored. Then: if dest is exhausted, or the current dest key
        sorts after the src key, the file is missing from dest and is yielded.
        If the keys match, the comparator decides whether the dest copy is
        stale; either way the dest pointer advances past the match.

        Bug fix: the previous implementation never advanced the dest pointer
        past alphabetically-earlier dest keys, so any extra file in dest
        caused every later src file to be silently skipped.
        """
        dest_iter = iter(self.dest_files)
        dest_file = next(dest_iter, None)

        for src_file in self.src_files:
            # Ignore dest entries that sort before the current src key.
            while dest_file is not None and self._is_alphabetically_earlier_file_name(
                dest_file.location.key, src_file.location.key
            ):
                dest_file = next(dest_iter, None)

            if dest_file is None:
                # Dest exhausted: every remaining src file must be copied.
                yield src_file
                continue

            if self._is_same_file_name(src_file.location.key, dest_file.location.key):
                if self.comparator.determine_should_sync(src_file, dest_file):
                    yield src_file
                dest_file = next(dest_iter, None)
            else:
                # src key sorts before the current dest key: missing from dest.
                yield src_file

    def _is_same_file_name(self, src_filename: str, dest_filename: str) -> bool:
        """Determines if two files have the same file name."""
        return src_filename == dest_filename

    def _is_alphabetically_earlier_file_name(self, src_filename: str, dest_filename: str) -> bool:
        """Determines if one filename is alphabetically earlier than another."""
        return src_filename < dest_filename
@dataclass
class S3ObjectLocation:
    """Helper class for S3 object references"""

    bucket: str
    key: str

    def format_for_s3_copy(self) -> Dict[str, str]:
        """Returns a dict formatted for S3 copy calls"""
        return {"Bucket": self.bucket, "Key": self.key}

    def get_uri(self) -> str:
        """Returns the s3 URI"""
        return "s3://{}/{}".format(self.bucket, self.key)


@dataclass
class JumpStartModelInfo:
    """Helper class for storing JumpStart model info."""

    model_id: str
    version: str


class HubContentDependencyType(str, Enum):
    """Enum class for HubContent dependency names"""

    INFERENCE_ARTIFACT = "inference_artifact_s3_reference"
    TRAINING_ARTIFACT = "training_artifact_s3_reference"
    INFERENCE_SCRIPT = "inference_script_s3_reference"
    TRAINING_SCRIPT = "training_script_s3_reference"
    DEFAULT_TRAINING_DATASET = "default_training_dataset_s3_reference"
    DEMO_NOTEBOOK = "demo_notebook_s3_reference"
    MARKDOWN = "markdown_s3_reference"
def get_framework(self) -> str:
    """Returns the framework for the model.

    The framework is the leading ``-``-separated token of the model ID,
    e.g. ``pytorch`` for ``pytorch-ic-mobilenet-v2``.
    """
    framework, _, _ = self.model_id.partition("-")
    return framework
def generate_studio_spec_file_prefix(model_id: str, model_version: str) -> str:
    """Returns the Studio Spec file prefix given a model ID and version."""
    return "studio_models/{}/studio_specs_v{}.json".format(model_id, model_version)
from __future__ import absolute_import +from copy import deepcopy +import datetime +from unittest import mock +from unittest.mock import patch import pytest from mock import Mock from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub +from sagemaker.jumpstart.curated_hub.types import JumpStartModelInfo, S3ObjectLocation +from sagemaker.jumpstart.types import JumpStartModelSpecs +from tests.unit.sagemaker.jumpstart.constants import BASE_SPEC +from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec + REGION = "us-east-1" ACCOUNT_ID = "123456789123" HUB_NAME = "mock-hub-name" +MODULE_PATH = "sagemaker.jumpstart.curated_hub.curated_hub.CuratedHub" + +FAKE_TIME = datetime.datetime(1997, 8, 14, 00, 00, 00) + @pytest.fixture() def sagemaker_session(): @@ -29,6 +42,9 @@ def sagemaker_session(): sagemaker_session_mock._client_config.user_agent = ( "Boto3/1.9.69 Python/3.6.5 Linux/4.14.77-70.82.amzn1.x86_64 Botocore/1.12.69 Resource" ) + sagemaker_session_mock.describe_hub.return_value = { + "S3StorageConfig": {"S3OutputPath": "s3://mock-bucket-123"} + } sagemaker_session_mock.account_id.return_value = ACCOUNT_ID return sagemaker_session_mock @@ -54,7 +70,9 @@ def test_instantiates(sagemaker_session): ), ], ) +@patch("sagemaker.jumpstart.curated_hub.curated_hub.CuratedHub._generate_hub_storage_location") def test_create_with_no_bucket_name( + mock_generate_hub_storage_location, sagemaker_session, hub_name, hub_description, @@ -63,21 +81,29 @@ def test_create_with_no_bucket_name( hub_search_keywords, tags, ): + storage_location = S3ObjectLocation( + "sagemaker-hubs-us-east-1-123456789123", f"{hub_name}-{FAKE_TIME.timestamp()}" + ) + mock_generate_hub_storage_location.return_value = storage_location create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} sagemaker_session.create_hub = Mock(return_value=create_hub) + sagemaker_session.describe_hub.return_value = { + "S3StorageConfig": {"S3OutputPath": 
f"s3://{hub_bucket_name}/{storage_location.key}"} + } hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session) request = { "hub_name": hub_name, "hub_description": hub_description, - "hub_bucket_name": "sagemaker-hubs-us-east-1-123456789123", "hub_display_name": hub_display_name, "hub_search_keywords": hub_search_keywords, + "s3_storage_config": { + "S3OutputPath": f"s3://sagemaker-hubs-us-east-1-123456789123/{storage_location.key}" + }, "tags": tags, } response = hub.create( description=hub_description, display_name=hub_display_name, - bucket_name=hub_bucket_name, search_keywords=hub_search_keywords, tags=tags, ) @@ -99,7 +125,9 @@ def test_create_with_no_bucket_name( ), ], ) +@patch("sagemaker.jumpstart.curated_hub.curated_hub.CuratedHub._generate_hub_storage_location") def test_create_with_bucket_name( + mock_generate_hub_storage_location, sagemaker_session, hub_name, hub_description, @@ -108,23 +136,369 @@ def test_create_with_bucket_name( hub_search_keywords, tags, ): + storage_location = S3ObjectLocation(hub_bucket_name, f"{hub_name}-{FAKE_TIME.timestamp()}") + mock_generate_hub_storage_location.return_value = storage_location create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} sagemaker_session.create_hub = Mock(return_value=create_hub) - hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session) + hub = CuratedHub( + hub_name=hub_name, sagemaker_session=sagemaker_session, bucket_name=hub_bucket_name + ) request = { "hub_name": hub_name, "hub_description": hub_description, - "hub_bucket_name": hub_bucket_name, "hub_display_name": hub_display_name, "hub_search_keywords": hub_search_keywords, + "s3_storage_config": {"S3OutputPath": f"s3://mock-bucket-123/{storage_location.key}"}, "tags": tags, } response = hub.create( description=hub_description, display_name=hub_display_name, - bucket_name=hub_bucket_name, search_keywords=hub_search_keywords, tags=tags, ) 
sagemaker_session.create_hub.assert_called_with(**request) assert response == {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +@patch(f"{MODULE_PATH}._sync_public_model_to_hub") +@patch(f"{MODULE_PATH}.list_models") +def test_sync_kicks_off_parallel_syncs( + mock_list_models, mock_sync_public_models, mock_get_model_specs, sagemaker_session +): + mock_get_model_specs.side_effect = get_spec_from_base_spec + mock_list_models.return_value = {"HubContentSummaries": []} + hub_name = "mock_hub_name" + model_one = {"model_id": "mock-model-one-huggingface"} + model_two = {"model_id": "mock-model-two-pytorch", "version": "1.0.2"} + mock_sync_public_models.return_value = "" + hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session) + + hub.sync([model_one, model_two]) + + mock_sync_public_models.assert_has_calls( + [ + mock.call(JumpStartModelInfo("mock-model-one-huggingface", "*"), 0), + mock.call(JumpStartModelInfo("mock-model-two-pytorch", "1.0.2"), 1), + ] + ) + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +@patch(f"{MODULE_PATH}._sync_public_model_to_hub") +@patch(f"{MODULE_PATH}.list_models") +def test_sync_filters_models_that_exist_in_hub( + mock_list_models, mock_sync_public_models, mock_get_model_specs, sagemaker_session +): + mock_get_model_specs.side_effect = get_spec_from_base_spec + mock_list_models.return_value = { + "HubContentSummaries": [ + { + "name": "mock-model-two-pytorch", + "version": "1.0.2", + "search_keywords": [ + "@jumpstart-model-id:model-two-pytorch", + "@jumpstart-model-version:1.0.2", + ], + }, + {"name": "mock-model-three-nonsense", "version": "1.0.2", "search_keywords": []}, + { + "name": "mock-model-four-huggingface", + "version": "2.0.2", + "search_keywords": [ + "@jumpstart-model-id:model-four-huggingface", + "@jumpstart-model-version:2.0.2", + ], + }, + ] + } + hub_name = 
"mock_hub_name" + model_one = {"model_id": "mock-model-one-huggingface"} + model_two = {"model_id": "mock-model-two-pytorch", "version": "1.0.2"} + mock_sync_public_models.return_value = "" + hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session) + + hub.sync([model_one, model_two]) + + mock_sync_public_models.assert_called_once_with( + JumpStartModelInfo("mock-model-one-huggingface", "*"), 0 + ) + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +@patch(f"{MODULE_PATH}._sync_public_model_to_hub") +@patch(f"{MODULE_PATH}.list_models") +def test_sync_updates_old_models_in_hub( + mock_list_models, mock_sync_public_models, mock_get_model_specs, sagemaker_session +): + mock_get_model_specs.side_effect = get_spec_from_base_spec + mock_list_models.return_value = { + "HubContentSummaries": [ + { + "name": "mock-model-two-pytorch", + "version": "1.0.1", + "search_keywords": [ + "@jumpstart-model-id:model-two-pytorch", + "@jumpstart-model-version:1.0.2", + ], + }, + { + "name": "mock-model-three-nonsense", + "version": "1.0.2", + "search_keywords": ["tag-one", "tag-two"], + }, + { + "name": "mock-model-four-huggingface", + "version": "2.0.2", + "search_keywords": [ + "@jumpstart-model-id:model-four-huggingface", + "@jumpstart-model-version:2.0.2", + ], + }, + ] + } + hub_name = "mock_hub_name" + model_one = {"model_id": "mock-model-one-huggingface"} + model_two = {"model_id": "mock-model-two-pytorch", "version": "1.0.2"} + mock_sync_public_models.return_value = "" + hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session) + + hub.sync([model_one, model_two]) + + mock_sync_public_models.assert_has_calls( + [ + mock.call(JumpStartModelInfo("mock-model-one-huggingface", "*"), 0), + mock.call(JumpStartModelInfo("mock-model-two-pytorch", "1.0.2"), 1), + ] + ) + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +@patch(f"{MODULE_PATH}._sync_public_model_to_hub") 
+@patch(f"{MODULE_PATH}.list_models") +def test_sync_passes_newer_hub_models( + mock_list_models, mock_sync_public_models, mock_get_model_specs, sagemaker_session +): + mock_get_model_specs.side_effect = get_spec_from_base_spec + mock_list_models.return_value = { + "HubContentSummaries": [ + { + "name": "mock-model-two-pytorch", + "version": "1.0.3", + "search_keywords": [ + "@jumpstart-model-id:model-two-pytorch", + "@jumpstart-model-version:1.0.2", + ], + }, + { + "name": "mock-model-three-nonsense", + "version": "1.0.2", + "search_keywords": ["tag-one", "tag-two"], + }, + { + "name": "mock-model-four-huggingface", + "version": "2.0.2", + "search_keywords": [ + "@jumpstart-model-id:model-four-huggingface", + "@jumpstart-model-version:2.0.2", + ], + }, + ] + } + hub_name = "mock_hub_name" + model_one = {"model_id": "mock-model-one-huggingface"} + model_two = {"model_id": "mock-model-two-pytorch", "version": "1.0.2"} + mock_sync_public_models.return_value = "" + hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session) + + hub.sync([model_one, model_two]) + + mock_sync_public_models.assert_called_once_with( + JumpStartModelInfo("mock-model-one-huggingface", "*"), 0 + ) + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_populate_latest_model_version(mock_get_model_specs, sagemaker_session): + mock_get_model_specs.return_value = JumpStartModelSpecs(deepcopy(BASE_SPEC)) + + hub_name = "mock_hub_name" + hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session) + + res = hub._populate_latest_model_version({"model_id": "mock-pytorch-model-one", "version": "*"}) + assert res == {"model_id": "mock-pytorch-model-one", "version": "1.0.0"} + + res = hub._populate_latest_model_version({"model_id": "mock-pytorch-model-one"}) + assert res == {"model_id": "mock-pytorch-model-one", "version": "1.0.0"} + + # Should take latest version from specs no matter what. 
Parent should responsibly call + res = hub._populate_latest_model_version( + {"model_id": "mock-pytorch-model-one", "version": "2.0.0"} + ) + assert res == {"model_id": "mock-pytorch-model-one", "version": "1.0.0"} + + +@patch(f"{MODULE_PATH}.list_models") +def test_get_jumpstart_models_in_hub(mock_list_models, sagemaker_session): + mock_list_models.return_value = { + "HubContentSummaries": [ + { + "name": "mock-model-two-pytorch", + "version": "1.0.3", + "search_keywords": [ + "@jumpstart-model-id:model-two-pytorch", + "@jumpstart-model-version:1.0.3", + ], + }, + { + "name": "mock-model-three-nonsense", + "version": "1.0.2", + "search_keywords": ["tag-one", "tag-two"], + }, + { + "name": "mock-model-four-huggingface", + "version": "2.0.2", + "search_keywords": [ + "@jumpstart-model-id:model-four-huggingface", + "@jumpstart-model-version:2.0.2", + ], + }, + ] + } + + hub_name = "mock_hub_name" + hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session) + + res = hub._get_jumpstart_models_in_hub() + assert res == [ + { + "name": "mock-model-two-pytorch", + "version": "1.0.3", + "search_keywords": [ + "@jumpstart-model-id:model-two-pytorch", + "@jumpstart-model-version:1.0.3", + ], + }, + { + "name": "mock-model-four-huggingface", + "version": "2.0.2", + "search_keywords": [ + "@jumpstart-model-id:model-four-huggingface", + "@jumpstart-model-version:2.0.2", + ], + }, + ] + + mock_list_models.return_value = {"HubContentSummaries": []} + + res = hub._get_jumpstart_models_in_hub() + assert res == [] + + mock_list_models.return_value = { + "HubContentSummaries": [ + { + "name": "mock-model-three-nonsense", + "version": "1.0.2", + "search_keywords": ["tag-one", "tag-two"], + }, + ] + } + + res = hub._get_jumpstart_models_in_hub() + assert res == [] + + +def test_determine_models_to_sync(sagemaker_session): + hub_name = "mock_hub_name" + hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session) + + js_model_map = { + "mock-model-two-pytorch": 
{ + "name": "mock-model-two-pytorch", + "version": "1.0.1", + "search_keywords": [ + "@jumpstart-model-id:model-two-pytorch", + "@jumpstart-model-version:1.0.2", + ], + }, + "mock-model-four-huggingface": { + "name": "mock-model-four-huggingface", + "version": "2.0.2", + "search_keywords": [ + "@jumpstart-model-id:model-four-huggingface", + "@jumpstart-model-version:2.0.2", + ], + }, + } + model_one = JumpStartModelInfo("mock-model-one-huggingface", "1.2.3") + model_two = JumpStartModelInfo("mock-model-two-pytorch", "1.0.2") + # No model_one, older model_two + res = hub._determine_models_to_sync([model_one, model_two], js_model_map) + assert res == [model_one, model_two] + + js_model_map = { + "mock-model-two-pytorch": { + "name": "mock-model-two-pytorch", + "version": "1.0.3", + "search_keywords": [ + "@jumpstart-model-id:model-two-pytorch", + "@jumpstart-model-version:1.0.3", + ], + }, + "mock-model-four-huggingface": { + "name": "mock-model-four-huggingface", + "version": "2.0.2", + "search_keywords": [ + "@jumpstart-model-id:model-four-huggingface", + "@jumpstart-model-version:2.0.2", + ], + }, + } + # No model_one, newer model_two + res = hub._determine_models_to_sync([model_one, model_two], js_model_map) + assert res == [model_one] + + js_model_map = { + "mock-model-one-huggingface": { + "name": "mock-model-one-huggingface", + "version": "1.2.3", + "search_keywords": [ + "@jumpstart-model-id:model-one-huggingface", + "@jumpstart-model-version:1.2.3", + ], + }, + "mock-model-two-pytorch": { + "name": "mock-model-two-pytorch", + "version": "1.0.2", + "search_keywords": [ + "@jumpstart-model-id:model-two-pytorch", + "@jumpstart-model-version:1.0.2", + ], + }, + } + # Same model_one, same model_two + res = hub._determine_models_to_sync([model_one, model_two], js_model_map) + assert res == [] + + js_model_map = { + "mock-model-one-huggingface": { + "name": "mock-model-one-huggingface", + "version": "1.2.1", + "search_keywords": [ + 
"@jumpstart-model-id:model-one-huggingface", + "@jumpstart-model-version:1.2.1", + ], + }, + "mock-model-two-pytorch": { + "name": "mock-model-two-pytorch", + "version": "1.0.2", + "search_keywords": [ + "@jumpstart-model-id:model-two-pytorch", + "@jumpstart-model-version:1.0.2", + ], + }, + } + # Old model_one, same model_two + res = hub._determine_models_to_sync([model_one, model_two], js_model_map) + assert res == [model_one] diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py new file mode 100644 index 0000000000..accd2a5c8d --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py @@ -0,0 +1,129 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import +import pytest +from unittest.mock import Mock, patch +from sagemaker.jumpstart.curated_hub.accessors.file_generator import ( + generate_file_infos_from_model_specs, + generate_file_infos_from_s3_location, +) +from sagemaker.jumpstart.curated_hub.types import FileInfo, S3ObjectLocation + +from sagemaker.jumpstart.types import JumpStartModelSpecs +from tests.unit.sagemaker.jumpstart.constants import BASE_SPEC +from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec + + +@pytest.fixture() +def s3_client(): + mock_s3_client = Mock() + mock_s3_client.list_objects_v2.return_value = { + "Contents": [ + {"Key": "my-key-one", "Size": 123456789, "LastModified": "08-14-1997 00:00:00"} + ] + } + mock_s3_client.head_object.return_value = { + "ContentLength": 123456789, + "LastModified": "08-14-1997 00:00:00", + } + return mock_s3_client + + +def test_s3_path_file_generator_happy_path(s3_client): + s3_client.list_objects_v2.return_value = { + "Contents": [ + {"Key": "my-key-one", "Size": 123456789, "LastModified": "08-14-1997 00:00:00"}, + {"Key": "my-key-one", "Size": 10101010, "LastModified": "08-14-1997 00:00:00"}, + ] + } + + mock_hub_bucket = S3ObjectLocation(bucket="mock-bucket", key="mock-key") + response = generate_file_infos_from_s3_location(mock_hub_bucket, s3_client) + + s3_client.list_objects_v2.assert_called_once() + assert response == [ + FileInfo("mock-bucket", "my-key-one", 123456789, "08-14-1997 00:00:00"), + FileInfo("mock-bucket", "my-key-one", 10101010, "08-14-1997 00:00:00"), + ] + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_model_specs_file_generator_happy_path(patched_get_model_specs, s3_client): + patched_get_model_specs.side_effect = get_spec_from_base_spec + + specs = JumpStartModelSpecs(BASE_SPEC) + studio_specs = {"defaultDataKey": "model_id123"} + response = generate_file_infos_from_model_specs(specs, studio_specs, "us-west-2", s3_client) + + 
s3_client.head_object.assert_called() + patched_get_model_specs.assert_called() + + assert response == [ + FileInfo( + "jumpstart-cache-prod-us-west-2", + "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", + 123456789, + "08-14-1997 00:00:00", + ), + FileInfo( + "jumpstart-cache-prod-us-west-2", + "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", + 123456789, + "08-14-1997 00:00:00", + ), + FileInfo( + "jumpstart-cache-prod-us-west-2", + "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", + 123456789, + "08-14-1997 00:00:00", + ), + FileInfo( + "jumpstart-cache-prod-us-west-2", + "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", + 123456789, + "08-14-1997 00:00:00", + ), + FileInfo("jumpstart-cache-prod-us-west-2", "model_id123", 123456789, "08-14-1997 00:00:00"), + FileInfo( + "jumpstart-cache-prod-us-west-2", + "pytorch-notebooks/pytorch-ic-mobilenet-v2-inference.ipynb", + 123456789, + "08-14-1997 00:00:00", + ), + FileInfo( + "jumpstart-cache-prod-us-west-2", + "pytorch-metadata/pytorch-ic-mobilenet-v2.md", + 123456789, + "08-14-1997 00:00:00", + ), + ] + + +def test_s3_path_file_generator_with_no_objects(s3_client): + s3_client.list_objects_v2.return_value = {"Contents": []} + + mock_hub_bucket = S3ObjectLocation(bucket="mock-bucket", key="mock-key") + response = generate_file_infos_from_s3_location(mock_hub_bucket, s3_client) + + s3_client.list_objects_v2.assert_called_once() + assert response == [] + + s3_client.list_objects_v2.reset_mock() + + s3_client.list_objects_v2.return_value = {} + + mock_hub_bucket = S3ObjectLocation(bucket="mock-bucket", key="mock-key") + response = generate_file_infos_from_s3_location(mock_hub_bucket, s3_client) + + s3_client.list_objects_v2.assert_called_once() + assert response == [] diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_synccomparator.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_synccomparator.py new file mode 100644 index 
class SizeAndLastUpdateComparatorTest(unittest.TestCase):
    """Unit tests for ``SizeAndLastUpdatedComparator``."""

    comparator = SizeAndLastUpdatedComparator()

    # Bug fix: these tests previously relied on the *ordering* of back-to-back
    # datetime.today() calls to produce distinct timestamps, which is flaky if
    # two calls land on the same microsecond. Fixed timestamps make the
    # intended relationship explicit and deterministic.

    def test_identical_files_returns_false(self):
        now = datetime(2024, 1, 1, 12, 0, 0)
        file_one = FileInfo("bucket", "my-file-one", 123456789, now)
        file_two = FileInfo("bucket", "my-file-two", 123456789, now)

        assert self.comparator.determine_should_sync(file_one, file_two) is False

    def test_different_file_sizes_returns_true(self):
        now = datetime(2024, 1, 1, 12, 0, 0)
        file_one = FileInfo("bucket", "my-file-one", 123456789, now)
        file_two = FileInfo("bucket", "my-file-two", 10101010, now)

        assert self.comparator.determine_should_sync(file_one, file_two) is True

    def test_different_file_dates_returns_true(self):
        # src file (file_one) modified after dest file (file_two) -> must sync.
        file_two = FileInfo("bucket", "my-file-two", 123456789, datetime(2024, 1, 1))
        file_one = FileInfo("bucket", "my-file-one", 123456789, datetime(2024, 1, 2))

        assert self.comparator.determine_should_sync(file_one, file_two) is True
a/tests/unit/sagemaker/jumpstart/curated_hub/test_syncrequest.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_syncrequest.py new file mode 100644 index 0000000000..497f0f4d60 --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_syncrequest.py @@ -0,0 +1,138 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +from typing import List, Optional +from datetime import datetime + +import pytest +from sagemaker.jumpstart.curated_hub.sync.comparator import SizeAndLastUpdatedComparator + +from sagemaker.jumpstart.curated_hub.sync.request import HubSyncRequestFactory +from sagemaker.jumpstart.curated_hub.types import ( + FileInfo, + HubContentDependencyType, + S3ObjectLocation, +) + +COMPARATOR = SizeAndLastUpdatedComparator() + + +def _helper_generate_fileinfos( + num_infos: int, + bucket: Optional[str] = None, + key_prefix: Optional[str] = None, + size: Optional[int] = None, + last_updated: Optional[datetime] = None, + dependecy_type: Optional[HubContentDependencyType] = None, +) -> List[FileInfo]: + + file_infos = [] + for i in range(num_infos): + bucket = bucket or "default-bucket" + key_prefix = key_prefix or "mock-key" + size = size or 123456 + last_updated = last_updated or datetime.today() + + file_infos.append( + FileInfo( + bucket=bucket, + key=f"{key_prefix}-{i}", + size=size, + last_updated=last_updated, + dependecy_type={dependecy_type}, + ) + ) + return file_infos + + 
+@pytest.mark.parametrize( +    ("src_files,dest_files"), +    [ +        pytest.param(_helper_generate_fileinfos(8), []), +        pytest.param([], _helper_generate_fileinfos(8)), +        pytest.param([], []), +    ], +) +def test_sync_request_factory_edge_cases(src_files, dest_files): +    dest_location = S3ObjectLocation("mock-bucket-123", "mock-prefix") +    factory = HubSyncRequestFactory(src_files, dest_files, dest_location, COMPARATOR) + +    req = factory.create() + +    assert req.files == src_files +    assert req.destination == dest_location + + +def test_passes_existing_files_in_dest(): +    files = _helper_generate_fileinfos(4, key_prefix="aafile.py") +    tarballs = _helper_generate_fileinfos(3, key_prefix="bb.tar.gz") +    extra_files = _helper_generate_fileinfos(2, key_prefix="ccextrafiles.py") + +    src_files = [*tarballs, *files, *extra_files] +    dest_files = [files[1], files[2], tarballs[1]] + +    expected_response = [files[0], files[3], tarballs[0], tarballs[2], *extra_files] + +    dest_location = S3ObjectLocation("mock-bucket-123", "mock-prefix") +    factory = HubSyncRequestFactory(src_files, dest_files, dest_location, COMPARATOR) + +    req = factory.create() + +    assert req.files == expected_response + + +def test_adds_files_with_same_name_diff_size(): +    file_one = _helper_generate_fileinfos(1, key_prefix="file.py", size=101010)[0] +    file_two = _helper_generate_fileinfos(1, key_prefix="file.py", size=123456)[0] + +    src_files = [file_one] +    dest_files = [file_two] + +    dest_location = S3ObjectLocation("mock-bucket-123", "mock-prefix") +    factory = HubSyncRequestFactory(src_files, dest_files, dest_location, COMPARATOR) + +    req = factory.create() + +    assert req.files == src_files + + +def test_adds_files_with_same_name_dest_older_time(): +    file_dest = _helper_generate_fileinfos(1, key_prefix="file.py", last_updated=datetime.today())[ +        0 +    ] +    file_src = _helper_generate_fileinfos(1, key_prefix="file.py", last_updated=datetime.today())[0] + +    src_files = [file_src] +    dest_files = [file_dest] + +    dest_location = 
S3ObjectLocation("mock-bucket-123", "mock-prefix") +    factory = HubSyncRequestFactory(src_files, dest_files, dest_location, COMPARATOR) + +    req = factory.create() + +    assert req.files == src_files + + +def test_does_not_add_files_with_same_name_src_older_time(): +    file_src = _helper_generate_fileinfos(1, key_prefix="file.py", last_updated=datetime.today())[0] +    file_dest = _helper_generate_fileinfos(1, key_prefix="file.py", last_updated=datetime.today())[0] + +    src_files = [file_src] +    dest_files = [file_dest] + +    dest_location = S3ObjectLocation("mock-bucket-123", "mock-prefix") +    factory = HubSyncRequestFactory(src_files, dest_files, dest_location, COMPARATOR) + +    req = factory.create() + +    assert req.files == [] diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py index 59a0a8f958..15b6d8fba3 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py @@ -139,10 +139,7 @@ def test_generate_hub_arn_for_init_kwargs(): utils.generate_hub_arn_for_init_kwargs(hub_arn, "us-east-1", mock_custom_session) == hub_arn ) -    assert ( -        utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, None, mock_custom_session) -        == hub_arn -    ) +    assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn def test_generate_default_hub_bucket_name(): @@ -171,4 +168,3 @@ def test_create_hub_bucket_if_it_does_not_exist(): mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once() assert created_hub_bucket_name == bucket_name -    assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 423dbf5e02..db1102efb2 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -22,7 +22,6 @@ from
mock.mock import MagicMock import pytest from mock import patch -from sagemaker.session_settings import SessionSettings from sagemaker.jumpstart.cache import JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JumpStartModelsCache from sagemaker.jumpstart.constants import ( ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, @@ -874,7 +873,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") assert mocked_is_dir.call_count == 2 - assert mocked_open.call_count == 2 + mocked_open.assert_not_called() mocked_get_json_file_and_etag_from_s3.assert_has_calls( calls=[ call("models_manifest.json"), diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 82e69e1d89..7f842a053c 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -407,8 +407,9 @@ def test_jumpstart_model_specs(): assert specs1.to_json() == BASE_SPEC - BASE_SPEC["model_id"] = "diff model ID" - specs2 = JumpStartModelSpecs(BASE_SPEC) + diff_specs = copy.deepcopy(BASE_SPEC) + diff_specs["model_id"] = "diff model ID" + specs2 = JumpStartModelSpecs(diff_specs) assert specs1 != specs2 specs3 = copy.deepcopy(specs1)