From 344d26bdc8be8c038ee2bbf1040b4c4346bc04b4 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 4 Mar 2024 23:10:39 +0000 Subject: [PATCH 01/15] first pass at sync function with util classes --- src/sagemaker/jumpstart/cache.py | 11 +- .../curated_hub/accessors/__init__.py | 0 .../curated_hub/accessors/filegenerator.py | 112 ++++++++++ .../curated_hub/accessors/fileinfo.py | 47 +++++ .../curated_hub/accessors/objectlocation.py | 46 ++++ .../accessors/public_model_data.py | 111 ++++++++++ .../jumpstart/curated_hub/accessors/sync.py | 74 +++++++ .../curated_hub/accessors/synccomparator.py | 72 +++++++ .../jumpstart/curated_hub/curated_hub.py | 199 +++++++++++++++++- src/sagemaker/jumpstart/curated_hub/types.py | 40 ++++ src/sagemaker/jumpstart/curated_hub/utils.py | 6 + .../jumpstart/curated_hub/test_curated_hub.py | 6 +- .../curated_hub/test_filegenerator.py | 124 +++++++++++ .../jumpstart/curated_hub/test_sync.py | 0 .../curated_hub/test_synccomparator.py | 41 ++++ .../jumpstart/curated_hub/test_utils.py | 6 +- tests/unit/sagemaker/jumpstart/test_cache.py | 3 +- 17 files changed, 881 insertions(+), 17 deletions(-) create mode 100644 src/sagemaker/jumpstart/curated_hub/accessors/__init__.py create mode 100644 src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py create mode 100644 src/sagemaker/jumpstart/curated_hub/accessors/fileinfo.py create mode 100644 src/sagemaker/jumpstart/curated_hub/accessors/objectlocation.py create mode 100644 src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py create mode 100644 src/sagemaker/jumpstart/curated_hub/accessors/sync.py create mode 100644 src/sagemaker/jumpstart/curated_hub/accessors/synccomparator.py create mode 100644 src/sagemaker/jumpstart/curated_hub/types.py create mode 100644 tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py create mode 100644 tests/unit/sagemaker/jumpstart/curated_hub/test_sync.py create mode 100644 tests/unit/sagemaker/jumpstart/curated_hub/test_synccomparator.py diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 6cbc5b30cb..9d804dc53a 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -101,7 +101,8 @@ def __init__( Default: None (no config). s3_client (Optional[boto3.client]): s3 client to use. Default: None. sagemaker_session (Optional[sagemaker.session.Session]): A SageMaker Session object, - used for SageMaker interactions. Default: Session in region associated with boto3 session. + used for SageMaker interactions. Default: Session in region associated with boto3 + session. """ self._region = region @@ -358,7 +359,9 @@ def _retrieval_function( hub_content_type=data_type ) - model_specs = JumpStartModelSpecs(DescribeHubContentsResponse(hub_model_description), is_hub_content=True) + model_specs = JumpStartModelSpecs( + DescribeHubContentsResponse(hub_model_description), is_hub_content=True + ) utils.emit_logs_based_on_model_specs( model_specs, @@ -372,7 +375,9 @@ def _retrieval_function( hub_name, _, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info) response: Dict[str, Any] = self._sagemaker_session.describe_hub(hub_name=hub_name) hub_description = DescribeHubResponse(response) - return JumpStartCachedContentValue(formatted_content=DescribeHubResponse(hub_description)) + return JumpStartCachedContentValue( + formatted_content=DescribeHubResponse(hub_description) + ) raise ValueError( f"Bad value for key '{key}': must be in ", f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubType.HUB, HubContentType.MODEL]}" diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/__init__.py b/src/sagemaker/jumpstart/curated_hub/accessors/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py b/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py new file mode 100644 index 0000000000..e54f219f1a --- /dev/null +++ b/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py @@ -0,0 +1,112 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains important utilities related to HubContent data files.""" +from __future__ import absolute_import +from functools import singledispatchmethod +from typing import Any, Dict, List, Optional + +from botocore.client import BaseClient + +from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo, HubContentDependencyType +from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation +from sagemaker.jumpstart.curated_hub.accessors.public_model_data import PublicModelDataAccessor +from sagemaker.jumpstart.types import JumpStartModelSpecs + + +class FileGenerator: + """Utility class to help format HubContent data files.""" + + def __init__( + self, region: str, s3_client: BaseClient, studio_specs: Optional[Dict[str, Any]] = None + ): + self.region = region + self.s3_client = s3_client + self.studio_specs = studio_specs + + @singledispatchmethod + def format(self, file_input) -> List[FileInfo]: + """Implement.""" + # pylint: disable=W0107 + pass + + @format.register + def _(self, file_input: S3ObjectLocation) -> List[FileInfo]: + """Something.""" + files = self.s3_format(file_input) + return files + + @format.register + def _(self, file_input: JumpStartModelSpecs) -> List[FileInfo]: + """Something.""" + files = self.specs_format(file_input, self.studio_specs) + return files + + def s3_format(self, file_input: S3ObjectLocation) -> List[FileInfo]: + """Retrieves data from a bucket and formats into FileInfo""" + parameters = {"Bucket": file_input.bucket, "Prefix": file_input.key} + response = self.s3_client.list_objects_v2(**parameters) + contents = response.get("Contents", None) + + if not contents: + print("Nothing to download") + return [] + + files = [] + for s3_obj in contents: + key: str = s3_obj.get("Key") + size: bytes = s3_obj.get("Size", None) + last_modified: str = s3_obj.get("LastModified", None) + files.append(FileInfo(key, size, last_modified)) + return files + + def specs_format( + self, file_input: JumpStartModelSpecs, studio_specs: Dict[str, Any] + ) -> List[FileInfo]: + """Collects data locations from JumpStart public model specs and + converts into FileInfo. + """ + public_model_data_accessor = PublicModelDataAccessor( + region=self.region, model_specs=file_input, studio_specs=studio_specs + ) + function_table = { + HubContentDependencyType.INFERENCE_ARTIFACT: ( + public_model_data_accessor.get_inference_artifact_s3_reference + ), + HubContentDependencyType.TRAINING_ARTIFACT: ( + public_model_data_accessor.get_training_artifact_s3_reference + ), + HubContentDependencyType.INFERNECE_SCRIPT: ( + public_model_data_accessor.get_inference_script_s3_reference + ), + HubContentDependencyType.TRAINING_SCRIPT: ( + public_model_data_accessor.get_training_script_s3_reference + ), + HubContentDependencyType.DEFAULT_TRAINING_DATASET: ( + public_model_data_accessor.get_default_training_dataset_s3_reference + ), + HubContentDependencyType.DEMO_NOTEBOOK: ( + public_model_data_accessor.get_demo_notebook_s3_reference + ), + HubContentDependencyType.MARKDOWN: public_model_data_accessor.get_markdown_s3_reference, + } + files = [] + for dependency in HubContentDependencyType: + location = function_table[dependency]() + parameters = {"Bucket": location.bucket, "Prefix": location.key} + response = self.s3_client.head_object(**parameters) + key: str = location.key + size: bytes = response.get("ContentLength", None) + last_updated: str = response.get("LastModified", None) + dependency_type: HubContentDependencyType = dependency + files.append(FileInfo(key, size, last_updated, dependency_type)) + return files diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/fileinfo.py b/src/sagemaker/jumpstart/curated_hub/accessors/fileinfo.py new file mode 100644 index 0000000000..2407560a9e --- /dev/null +++ b/src/sagemaker/jumpstart/curated_hub/accessors/fileinfo.py @@ -0,0 +1,47 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains important details related to HubContent data files.""" +from __future__ import absolute_import + +from enum import Enum +from dataclasses import dataclass +from typing import Optional + + +class HubContentDependencyType(str, Enum): + """Enum class for HubContent dependency names""" + + INFERENCE_ARTIFACT = "INFERENCE_ARTIFACT" + TRAINING_ARTIFACT = "TRAINING_ARTIFACT" + INFERNECE_SCRIPT = "INFERENCE_SCRIPT" + TRAINING_SCRIPT = "TRAINING_SCRIPT" + DEFAULT_TRAINING_DATASET = "DEFAULT_TRAINING_DATASET" + DEMO_NOTEBOOK = "DEMO_NOTEBOOK" + MARKDOWN = "MARKDOWN" + + +@dataclass +class FileInfo: + """Data class for additional S3 file info.""" + + def __init__( + self, + name: str, + size: Optional[bytes], + last_updated: Optional[str], + dependecy_type: Optional[HubContentDependencyType] = None, + ): + self.name = name + self.size = size + self.last_updated = last_updated + self.dependecy_type = dependecy_type diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/objectlocation.py b/src/sagemaker/jumpstart/curated_hub/accessors/objectlocation.py new file mode 100644 index 0000000000..b0512bed2e --- /dev/null +++ b/src/sagemaker/jumpstart/curated_hub/accessors/objectlocation.py @@ -0,0 +1,46 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module utilites to assist S3 client calls for the Curated Hub.""" +from __future__ import absolute_import +from dataclasses import dataclass +from typing import Dict + + +@dataclass +class S3ObjectLocation: + """Helper class for S3 object references""" + + bucket: str + key: str + + def format_for_s3_copy(self) -> Dict[str, str]: + """Returns a dict formatted for S3 copy calls""" + return { + "Bucket": self.bucket, + "Key": self.key, + } + + def get_uri(self) -> str: + """Returns the s3 URI""" + return f"s3://{self.bucket}/{self.key}" + + +def create_s3_object_reference_from_uri(s3_uri: str) -> S3ObjectLocation: + """Utiity to help generate an S3 object reference""" + uri_with_s3_prefix_removed = s3_uri.replace("s3://", "", 1) + uri_split = uri_with_s3_prefix_removed.split("/") + + return S3ObjectLocation( + bucket=uri_split[0], + key="/".join(uri_split[1:]) if len(uri_split) > 1 else "", + ) diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py new file mode 100644 index 0000000000..c84c2127e9 --- /dev/null +++ b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py @@ -0,0 +1,111 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module accessors for the SageMaker JumpStart Public Hub.""" +from __future__ import absolute_import +from typing import Dict, Any +from sagemaker import model_uris, script_uris +from sagemaker.jumpstart.curated_hub.utils import ( + get_model_framework, +) +from sagemaker.jumpstart.enums import JumpStartScriptScope +from sagemaker.jumpstart.types import JumpStartModelSpecs +from sagemaker.jumpstart.utils import get_jumpstart_content_bucket +from sagemaker.jumpstart.curated_hub.accessors.objectlocation import ( + S3ObjectLocation, + create_s3_object_reference_from_uri, +) + + +class PublicModelDataAccessor: + """Accessor class for JumpStart model data s3 locations.""" + + def __init__( + self, + region: str, + model_specs: JumpStartModelSpecs, + studio_specs: Dict[str, Dict[str, Any]], + ): + self._region = region + self._bucket = get_jumpstart_content_bucket(region) + self.model_specs = model_specs + self.studio_specs = studio_specs # Necessary for SDK - Studio metadata drift + + def get_bucket_name(self) -> str: + """Retrieves s3 bucket""" + return self._bucket + + def get_inference_artifact_s3_reference(self) -> S3ObjectLocation: + """Retrieves s3 reference for model inference artifact""" + return create_s3_object_reference_from_uri( + self._jumpstart_artifact_s3_uri(JumpStartScriptScope.INFERENCE) + ) + + def get_training_artifact_s3_reference(self) -> S3ObjectLocation: + """Retrieves s3 reference for model training artifact""" + return create_s3_object_reference_from_uri( + self._jumpstart_artifact_s3_uri(JumpStartScriptScope.TRAINING) + ) + + def get_inference_script_s3_reference(self) -> S3ObjectLocation: + """Retrieves s3 reference for model inference script""" + return create_s3_object_reference_from_uri( + self._jumpstart_script_s3_uri(JumpStartScriptScope.INFERENCE) + ) + + def get_training_script_s3_reference(self) -> S3ObjectLocation: + """Retrieves s3 reference for model training script""" + return create_s3_object_reference_from_uri( + self._jumpstart_script_s3_uri(JumpStartScriptScope.TRAINING) + ) + + def get_default_training_dataset_s3_reference(self) -> S3ObjectLocation: + """Retrieves s3 reference for s3 directory containing model training datasets""" + return S3ObjectLocation(self.get_bucket_name(), self._get_training_dataset_prefix()) + + def _get_training_dataset_prefix(self) -> str: + """Retrieves training dataset location""" + return self.studio_specs["defaultDataKey"] + + def get_demo_notebook_s3_reference(self) -> S3ObjectLocation: + """Retrieves s3 reference for model demo jupyter notebook""" + framework = get_model_framework(self.model_specs) + key = f"{framework}-notebooks/{self.model_specs.model_id}-inference.ipynb" + return S3ObjectLocation(self.get_bucket_name(), key) + + def get_markdown_s3_reference(self) -> S3ObjectLocation: + """Retrieves s3 reference for model markdown""" + framework = get_model_framework(self.model_specs) + key = f"{framework}-metadata/{self.model_specs.model_id}.md" + return S3ObjectLocation(self.get_bucket_name(), key) + + def _jumpstart_script_s3_uri(self, model_scope: str) -> str: + """Retrieves JumpStart script s3 location""" + return script_uris.retrieve( + region=self._region, + model_id=self.model_specs.model_id, + model_version=self.model_specs.version, + script_scope=model_scope, + tolerate_vulnerable_model=True, + tolerate_deprecated_model=True, + ) + + def _jumpstart_artifact_s3_uri(self, model_scope: str) -> str: + """Retrieves JumpStart artifact s3 location""" + return model_uris.retrieve( + region=self._region, + model_id=self.model_specs.model_id, + model_version=self.model_specs.version, + model_scope=model_scope, + tolerate_vulnerable_model=True, + tolerate_deprecated_model=True, + ) diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/sync.py b/src/sagemaker/jumpstart/curated_hub/accessors/sync.py new file mode 100644 index 0000000000..f3be606632 --- /dev/null +++ b/src/sagemaker/jumpstart/curated_hub/accessors/sync.py @@ -0,0 +1,74 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module provides a class to help copy HubContent dependencies.""" +from __future__ import absolute_import +from typing import Generator, List + +from botocore.compat import six +from sagemaker.jumpstart.curated_hub.accessors.synccomparator import SizeAndLastUpdatedComparator + +from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo + +advance_iterator = six.advance_iterator + + +class FileSync: + """Something.""" + + def __init__(self, src_files: List[FileInfo], dest_files: List[FileInfo], dest_bucket: str): + """Instantiates a ``FileSync`` class. + Sorts src and dest files by name for easier + comparisons. + + Args: + src_files (List[FileInfo]): List of files to sync with destination + dest_files (List[FileInfo]): List of files already in destination bucket + dest_bucket (str): Destination bucket name for copied data + """ + self.comparator = SizeAndLastUpdatedComparator() + self.src_files: List[FileInfo] = sorted(src_files, lambda x: x.name) + self.dest_files: List[FileInfo] = sorted(dest_files, lambda x: x.name) + self.dest_bucket = dest_bucket + + def call(self) -> Generator[FileInfo, FileInfo, FileInfo]: + """This function performs the actual comparisons. Returns a list of FileInfo + to copy. + + Algorithm: + Loop over sorted files in the src directory. If at the end of dest directory, + add the src file to be copied and continue. Else, take the first file in the sorted + dest directory. If there is no dest file, signal we're at the end of dest and continue. + If there is a dest file, compare file names. If the file names are equivalent, + use the comparator to see if the dest file should be updated. If the file + names are not equivalent, increment the dest pointer. + """ + # :var dest_done: True if there are no files from the dest left. + dest_done = False + for src_file in self.src_files: + if dest_done: + yield src_file + continue + + while not dest_done: + try: + dest_file: FileInfo = advance_iterator(self.dest_files) + except StopIteration: + dest_done = True + break + + if src_file.name == dest_file.name: + should_sync = self.comparator.determine_should_sync(src_file, dest_file) + + if should_sync: + yield src_file + break diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/synccomparator.py b/src/sagemaker/jumpstart/curated_hub/accessors/synccomparator.py new file mode 100644 index 0000000000..8a478b1458 --- /dev/null +++ b/src/sagemaker/jumpstart/curated_hub/accessors/synccomparator.py @@ -0,0 +1,72 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module provides comparators for syncing s3 files.""" +from __future__ import absolute_import + +from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo + + +class SizeAndLastUpdatedComparator: + """Something.""" + + def determine_should_sync(self, src_file: FileInfo, dest_file: FileInfo) -> bool: + """Determines if src file should be moved to dest folder.""" + same_size = self.compare_size(src_file, dest_file) + same_last_modified_time = self.compare_time(src_file, dest_file) + should_sync = (not same_size) or (not same_last_modified_time) + if should_sync: + print( + "syncing: %s -> %s, size: %s -> %s, modified time: %s -> %s", + src_file.name, + src_file.name, + src_file.size, + dest_file.size, + src_file.last_updated, + dest_file.last_updated, + ) + return should_sync + + def total_seconds(self, td): + """ + timedelta's time_seconds() function for python 2.6 users + + :param td: The difference between two datetime objects. + """ + return (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10**6) / 10**6 + + def compare_size(self, src_file: FileInfo, dest_file: FileInfo): + """ + :returns: True if the sizes are the same. + False otherwise. + """ + return src_file.size == dest_file.size + + def compare_time(self, src_file: FileInfo, dest_file: FileInfo): + """ + :returns: True if the file does not need updating based on time of + last modification and type of operation. + False if the file does need updating based on the time of + last modification and type of operation. + """ + src_time = src_file.last_updated + dest_time = dest_file.last_updated + delta = dest_time - src_time + # pylint: disable=R1703,R1705 + if self.total_seconds(delta) >= 0: + # Destination is newer than source. + return True + else: + # Destination is older than source, so + # we have a more recently updated file + # at the source location. + return False diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 59a11df577..9dd9bbe330 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -12,9 +12,18 @@ # language governing permissions and limitations under the License. """This module provides the JumpStart Curated Hub class.""" from __future__ import absolute_import +from concurrent import futures +import traceback +from typing import Optional, Dict, List, Any +import boto3 +from botocore.client import BaseClient -from typing import Any, Dict, Optional +from sagemaker.jumpstart.curated_hub.accessors.filegenerator import FileGenerator +from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation +from sagemaker.jumpstart.curated_hub.accessors.sync import FileSync +from sagemaker.jumpstart.enums import JumpStartScriptScope +from sagemaker.jumpstart.utils import verify_model_region_and_return_specs from sagemaker.session import Session from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.types import ( @@ -22,7 +31,11 @@ DescribeHubContentsResponse, HubContentType, ) -from sagemaker.jumpstart.curated_hub.utils import create_hub_bucket_if_it_does_not_exist +from sagemaker.jumpstart.curated_hub.utils import ( + create_hub_bucket_if_it_does_not_exist, + generate_default_hub_bucket_name, +) +from sagemaker.jumpstart.curated_hub.types import HubContentDocument_v2 class CuratedHub: @@ -31,8 +44,9 @@ class CuratedHub: def __init__( self, hub_name: str, + bucket_name: Optional[str] = None, sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - ): + ) -> None: """Instantiates a SageMaker ``CuratedHub``. Args: @@ -43,18 +57,37 @@ def __init__( self.hub_name = hub_name self.region = sagemaker_session.boto_region_name self._sagemaker_session = sagemaker_session + self._default_thread_pool_size = 20 + self.hub_bucket_name = bucket_name or self._fetch_hub_bucket_name() + self._s3_client = self._get_s3_client() + + def _get_s3_client(self) -> BaseClient: + """Returns an S3 client.""" + return boto3.client("s3", region_name=self.region) + + def _fetch_hub_bucket_name(self) -> str: + """Retrieves hub bucket name from Hub config if exists""" + try: + hub_response = self._sagemaker_session.describe_hub(hub_name=self.hub_name) + hub_bucket_prefix = hub_response["S3StorageConfig"]["S3OutputPath"] + return hub_bucket_prefix # TODO: Strip s3:// prefix + except ValueError: + hub_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session) + print(f"Hub bucket name is: {hub_bucket_name}") # TODO: Better messaging + return hub_bucket_name def create( self, description: str, display_name: Optional[str] = None, search_keywords: Optional[str] = None, - bucket_name: Optional[str] = None, tags: Optional[str] = None, ) -> Dict[str, str]: """Creates a hub with the given description""" - bucket_name = create_hub_bucket_if_it_does_not_exist(bucket_name, self._sagemaker_session) + bucket_name = create_hub_bucket_if_it_does_not_exist( + self.hub_bucket_name, self._sagemaker_session + ) return self._sagemaker_session.create_hub( hub_name=self.hub_name, @@ -113,3 +146,159 @@ def delete_model(self, model_name: str, model_version: str = "*") -> None: def delete(self) -> None: """Deletes this Curated Hub""" return self._sagemaker_session.delete_hub(self.hub_name) + + def _is_invalid_model_list_input(self, model_list: List[Dict[str, str]]) -> bool: + """Determines if input args to ``sync`` is correct.""" + for obj in model_list: + if not isinstance(obj.get("model_id"), str): + return True + if "version" in obj and not isinstance(obj["version"], str): + return True + return False + + def sync(self, model_list: List[Dict[str, str]]): + """Syncs a list of JumpStart model ids and versions with a CuratedHub + + Args: + model_list (List[Dict[str, str]]): List of `{ model_id: str, version: Optional[str] }` + objects that should be synced into the Hub. + """ + if self._is_invalid_model_list_input(model_list): + raise ValueError( + "Model list should be a list of objects with values 'model_id',", + "and optional 'version'.", + ) + + # Fetch required information + # self._get_studio_manifest_map() + hub_models = self.list_models() + + # Retrieve latest version of unspecified JumpStart model versions + model_version_list = [] + for model in model_list: + # TODO: Uncomment and implement + # if not model["version"] or model["version"] == "*": + # model["version"] = self._find_latest_version(model_name=model.model_id) + model_version_list.append(model) + + # Find synced JumpStart model versions in the Hub + js_models_in_hub = [] + for model in hub_models: + # TODO: extract both in one pass + jumpstart_model_id = next( + (tag for tag in model.search_keywords if tag.startswith("@jumpstart-model-id")), + None, + ) + jumpstart_model_version = next( + ( + tag + for tag in model.search_keywords + if tag.startswith("@jumpstart-model-version") + ), + None, + ) + + if jumpstart_model_id and jumpstart_model_version: + js_models_in_hub.append(model) + + # Match inputted list of model versions with synced JumpStart model versions in the Hub + models_to_sync = [] + for model in model_version_list: + model_id, version = model + matched_model = next((model for model in js_models_in_hub if model.name == model_id)) + + # Model does not exist in Hub, sync + if not matched_model: + models_to_sync.append(model) + + if matched_model: + # 1. Model version exists in Hub, pass + if matched_model.version == version: + pass + + # 2. Invalid model version exists in Hub, pass + # This will only happen if something goes wrong in our metadata + if matched_model.version > version: + pass + + # 3. Old model version exists in Hub, update + if matched_model.version < version: + # Check minSDKVersion against current SDK version, emit log + models_to_sync.append(model) + + # Delete old models? + + # Copy content workflow + `SageMaker:ImportHubContent` for each model-to-sync in parallel + tasks: List[futures.Future] = [] + with futures.ThreadPoolExecutor( + max_workers=self._default_thread_pool_size, + thread_name_prefix="import-models-to-curated-hub", + ) as deploy_executor: + for model in models_to_sync: + task = deploy_executor.submit(self._sync_public_model_to_hub, model) + tasks.append(task) + + # Handle failed imports + results = futures.wait(tasks) + failed_imports: List[Dict[str, Any]] = [] + for result in results.done: + exception = result.exception() + if exception: + failed_imports.append( + { + "Exception": exception, + "Traceback": "".join( + traceback.TracebackException.from_exception(exception).format() + ), + } + ) + if failed_imports: + raise RuntimeError( + f"Failures when importing models to curated hub in parallel: {failed_imports}" + ) + + def _sync_public_model_to_hub(self, model: Dict[str, str]): + """Syncs a public JumpStart model version to the Hub. Runs in parallel.""" + model_name = model["name"] + model_version = model["version"] + + model_specs = verify_model_region_and_return_specs( + model_id=model_name, + version=model_version, + region=self.region, + scope=JumpStartScriptScope.INFERENCE, + sagemaker_session=self._sagemaker_session, + ) + # TODO: Uncomment and implement + # studio_specs = self.fetch_studio_specs(model_id=model_name, version=model_version) + studio_specs = {} + + dest_location = S3ObjectLocation( + bucket=self.hub_bucket_name, key=f"{model_name}/{model_version}" + ) + # TODO: Validations? HeadBucket? + + file_generator = FileGenerator(self.region, self._s3_client, studio_specs) + src_files = file_generator.format(model_specs) + dest_files = file_generator.format(dest_location) + + files_to_copy = list(FileSync(src_files, dest_files, dest_location).call()) + + if len(files_to_copy) > 0: + # Copy files with MPU + print("hi") + + hub_content_document = HubContentDocument_v2(spec=model_specs) + + self._sagemaker_session.import_hub_content( + document_schema_version=HubContentDocument_v2.SCHEMA_VERSION, + hub_content_name=model_name, + hub_content_version=model_version, + hub_name=self.hub_name, + hub_content_document=hub_content_document, + hub_content_type=HubContentType.MODEL, + hub_content_display_name="", + hub_content_description="", + hub_content_markdown="", + hub_content_search_keywords=[], + ) diff --git a/src/sagemaker/jumpstart/curated_hub/types.py b/src/sagemaker/jumpstart/curated_hub/types.py new file mode 100644 index 0000000000..4de41d9df7 --- /dev/null +++ b/src/sagemaker/jumpstart/curated_hub/types.py @@ -0,0 +1,40 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module stores types related to SageMaker JumpStart.""" +from __future__ import absolute_import +from typing import Dict, Any + +from sagemaker.jumpstart.types import JumpStartDataHolderType + + +class HubContentDocument_v2(JumpStartDataHolderType): + """Data class for HubContentDocument v2.0.0""" + + SCHEMA_VERSION = "2.0.0" + + def __init__(self, spec: Dict[str, Any]): + """Initializes a HubContentDocument_v2 object from JumpStart model specs. + + Args: + spec (Dict[str, Any]): Dictionary representation of spec. + """ + self.from_json(spec) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representatino of spec. + """ + # TODO: Implement + self.Url: str = json_obj["url"] diff --git a/src/sagemaker/jumpstart/curated_hub/utils.py b/src/sagemaker/jumpstart/curated_hub/utils.py index ac01da45ca..c425d05485 100644 --- a/src/sagemaker/jumpstart/curated_hub/utils.py +++ b/src/sagemaker/jumpstart/curated_hub/utils.py @@ -19,6 +19,7 @@ from sagemaker.jumpstart.types import ( HubContentType, HubArnExtractedInfo, + JumpStartModelSpecs, ) from sagemaker.jumpstart import constants @@ -152,3 +153,8 @@ def create_hub_bucket_if_it_does_not_exist( ) return bucket_name + + +def get_model_framework(model_specs: JumpStartModelSpecs) -> str: + """Retrieves the model framework from a model spec""" + return model_specs.model_id.split("-")[0] diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py index 2448721520..a2eabddbb8 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py @@ -29,6 +29,9 @@ def sagemaker_session(): sagemaker_session_mock._client_config.user_agent = ( "Boto3/1.9.69 Python/3.6.5 Linux/4.14.77-70.82.amzn1.x86_64 Botocore/1.12.69 Resource" ) + sagemaker_session_mock.describe_hub.return_value = { + "S3StorageConfig": {"S3OutputPath": "mock-bucket-123"} + } sagemaker_session_mock.account_id.return_value = ACCOUNT_ID return sagemaker_session_mock @@ -65,6 +68,7 @@ def test_create_with_no_bucket_name( ): create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} sagemaker_session.create_hub = Mock(return_value=create_hub) + sagemaker_session.describe_hub.return_value = {"S3StorageConfig": {"S3OutputPath": None}} hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session) request = { "hub_name": hub_name, @@ -77,7 +81,6 @@ def test_create_with_no_bucket_name( response = hub.create( description=hub_description, display_name=hub_display_name, - bucket_name=hub_bucket_name, search_keywords=hub_search_keywords, tags=tags, ) @@ -122,7 +125,6 @@ def test_create_with_bucket_name( response = hub.create( description=hub_description, display_name=hub_display_name, - bucket_name=hub_bucket_name, search_keywords=hub_search_keywords, tags=tags, ) diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py new file mode 100644 index 0000000000..65084c10b3 --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py @@ -0,0 +1,124 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +import pytest +from unittest.mock import Mock, patch +from sagemaker.jumpstart.curated_hub.accessors.filegenerator import FileGenerator +from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo + +from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation +from sagemaker.jumpstart.types import JumpStartModelSpecs +from tests.unit.sagemaker.jumpstart.constants import BASE_SPEC +from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec + + +@pytest.fixture() +def s3_client(): + mock_s3_client = Mock() + mock_s3_client.list_objects_v2.return_value = { + "Contents": [ + {"Key": "my-key-one", "Size": 123456789, "LastModified": "08-14-1997 00:00:00"} + ] + } + mock_s3_client.head_object.return_value = { + "ContentLength": 123456789, + "LastModified": "08-14-1997 00:00:00", + } + return mock_s3_client + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_format_has_two_different_input_pathways(patched_get_model_specs, s3_client): + patched_get_model_specs.side_effect = get_spec_from_base_spec + studio_specs = {"defaultDataKey": "model_id123"} + mock_hub_bucket = S3ObjectLocation(bucket="mock-bucket", key="mock-key") + generator = FileGenerator("us-west-2", s3_client, studio_specs) + generator.format(mock_hub_bucket) + + s3_client.list_objects_v2.assert_called_once() + s3_client.head_object.assert_not_called() + patched_get_model_specs.assert_not_called() + + # Other mocks shouldn't have been called + s3_client.list_objects_v2.reset_mock() + + specs = JumpStartModelSpecs(BASE_SPEC) + generator.format(specs) + + s3_client.list_objects_v2.assert_not_called() + s3_client.head_object.assert_called() + patched_get_model_specs.assert_called() + + +def test_object_location_input_works(s3_client): + s3_client.list_objects_v2.return_value = { + "Contents": [ + {"Key": "my-key-one", "Size": 123456789, "LastModified": "08-14-1997 00:00:00"}, + {"Key": "my-key-one", "Size": 10101010, "LastModified": "08-14-1997 00:00:00"}, + ] + } + + mock_hub_bucket = S3ObjectLocation(bucket="mock-bucket", key="mock-key") + generator = FileGenerator("us-west-2", s3_client) + response = generator.format(mock_hub_bucket) + + s3_client.list_objects_v2.assert_called_once() + assert response == [ + FileInfo("my-key-one", 123456789, "08-14-1997 00:00:00"), + FileInfo("my-key-one", 10101010, "08-14-1997 00:00:00"), + ] + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_specs_input_works(patched_get_model_specs, s3_client): + patched_get_model_specs.side_effect = get_spec_from_base_spec + + specs = JumpStartModelSpecs(BASE_SPEC) + studio_specs = {"defaultDataKey": "model_id123"} + generator = FileGenerator("us-west-2", s3_client, studio_specs) + response = generator.format(specs) + + s3_client.head_object.assert_called() + patched_get_model_specs.assert_called() + # TODO: Figure out why object attrs aren't being compared + assert response == [ + FileInfo("my-key-one", 123456789, "08-14-1997 00:00:00"), + FileInfo("my-key-two", 123456789, "08-14-1997 00:00:00"), + FileInfo("my-key-three", 123456789, "08-14-1997 00:00:00"), + FileInfo("my-key-four", 123456789, "08-14-1997 00:00:00"), + FileInfo("my-key-five", 123456789, "08-14-1997 00:00:00"), + FileInfo("my-key-six", 123456789, "08-14-1997 00:00:00"), + FileInfo("my-key-seven", 123456789, "08-14-1997 00:00:00"), + ] + + +def test_object_location_no_objects(s3_client): + s3_client.list_objects_v2.return_value = {"Contents": []} + + mock_hub_bucket = S3ObjectLocation(bucket="mock-bucket", key="mock-key") + generator = FileGenerator("us-west-2", s3_client) + response = generator.format(mock_hub_bucket) + + s3_client.list_objects_v2.assert_called_once() + assert response == [] + + s3_client.list_objects_v2.reset_mock() + + s3_client.list_objects_v2.return_value = {} + + mock_hub_bucket = S3ObjectLocation(bucket="mock-bucket", key="mock-key") + generator = FileGenerator("us-west-2", s3_client) + response = generator.format(mock_hub_bucket) + + s3_client.list_objects_v2.assert_called_once() + assert response == [] diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_sync.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_sync.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_synccomparator.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_synccomparator.py new file mode 100644 index 0000000000..74f7aafc9a --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_synccomparator.py @@ -0,0 +1,41 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +import unittest +from datetime import datetime +from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo + +from sagemaker.jumpstart.curated_hub.accessors.synccomparator import SizeAndLastUpdatedComparator + + +class SizeAndLastUpdateComparatorTest(unittest.TestCase): + comparator = SizeAndLastUpdatedComparator() + + def test_identical_files_returns_false(self): + file_one = FileInfo("my-file-one", 123456789, datetime.today()) + file_two = FileInfo("my-file-two", 123456789, datetime.today()) + + assert self.comparator.determine_should_sync(file_one, file_two) is False + + def test_different_file_sizes_returns_true(self): + file_one = FileInfo("my-file-one", 123456789, datetime.today()) + file_two = FileInfo("my-file-two", 10101010, datetime.today()) + + assert self.comparator.determine_should_sync(file_one, file_two) is True + + def test_different_file_dates_returns_true(self): + # change ordering of datetime.today() calls to trigger update + file_two = FileInfo("my-file-two", 123456789, datetime.today()) + file_one = FileInfo("my-file-one", 123456789, datetime.today()) + + assert self.comparator.determine_should_sync(file_one, file_two) is True diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py index 59a0a8f958..15b6d8fba3 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py @@ -139,10 +139,7 @@ def test_generate_hub_arn_for_init_kwargs(): utils.generate_hub_arn_for_init_kwargs(hub_arn, "us-east-1", mock_custom_session) == hub_arn ) - assert ( - utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, None, mock_custom_session) - == hub_arn - ) + assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn def test_generate_default_hub_bucket_name(): @@ -171,4 +168,3 @@ def test_create_hub_bucket_if_it_does_not_exist(): mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once() assert created_hub_bucket_name == bucket_name - assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 423dbf5e02..d87a2f957b 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -22,7 +22,6 @@ from mock.mock import MagicMock import pytest from mock import patch -from sagemaker.session_settings import SessionSettings from sagemaker.jumpstart.cache import JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JumpStartModelsCache from sagemaker.jumpstart.constants import ( ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, @@ -874,7 +873,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") assert mocked_is_dir.call_count == 2 - assert mocked_open.call_count == 2 + assert mocked_open.call_count == 2 # TODO: ?? mocked_get_json_file_and_etag_from_s3.assert_has_calls( calls=[ call("models_manifest.json"), From 374c6389e5453e48258bb7f3bcfcfcf1e72fbbfb Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Tue, 5 Mar 2024 17:20:51 +0000 Subject: [PATCH 02/15] adding tests and update clases --- .../curated_hub/accessors/filegenerator.py | 43 ++---- .../curated_hub/accessors/fileinfo.py | 14 +- .../accessors/public_model_data.py | 64 ++++++-- .../jumpstart/curated_hub/accessors/sync.py | 2 +- .../curated_hub/accessors/synccomparator.py | 11 +- .../jumpstart/curated_hub/curated_hub.py | 52 +++++-- .../jumpstart/curated_hub/test_curated_hub.py | 143 ++++++++++++++++++ tests/unit/sagemaker/jumpstart/test_cache.py | 2 +- 8 files changed, 260 insertions(+), 71 deletions(-) diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py b/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py index e54f219f1a..b8a6b299e8 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py @@ -35,19 +35,28 @@ def __init__( @singledispatchmethod def format(self, file_input) -> List[FileInfo]: - """Implement.""" + """Dispatch method that takes in an input of either ``S3ObjectLocation`` or + ``JumpStartModelSpecs`` and is implemented in below registered functions. + """ # pylint: disable=W0107 pass @format.register def _(self, file_input: S3ObjectLocation) -> List[FileInfo]: - """Something.""" + """Implements ``.format`` when the input is of type ``S3ObjectLocation``. + + Returns a list of ``FileInfo`` objects from the specified bucket location. + """ files = self.s3_format(file_input) return files @format.register def _(self, file_input: JumpStartModelSpecs) -> List[FileInfo]: - """Something.""" + """Implements ``.format`` when the input is of type ``JumpStartModelSpecs``. + + Returns a list of ``FileInfo`` objects from dependencies found in the public + model specs. + """ files = self.specs_format(file_input, self.studio_specs) return files @@ -72,36 +81,16 @@ def s3_format(self, file_input: S3ObjectLocation) -> List[FileInfo]: def specs_format( self, file_input: JumpStartModelSpecs, studio_specs: Dict[str, Any] ) -> List[FileInfo]: - """Collects data locations from JumpStart public model specs and - converts into FileInfo. + """ + Collects data locations from JumpStart public model specs and + converts into FileInfo. """ public_model_data_accessor = PublicModelDataAccessor( region=self.region, model_specs=file_input, studio_specs=studio_specs ) - function_table = { - HubContentDependencyType.INFERENCE_ARTIFACT: ( - public_model_data_accessor.get_inference_artifact_s3_reference - ), - HubContentDependencyType.TRAINING_ARTIFACT: ( - public_model_data_accessor.get_training_artifact_s3_reference - ), - HubContentDependencyType.INFERNECE_SCRIPT: ( - public_model_data_accessor.get_inference_script_s3_reference - ), - HubContentDependencyType.TRAINING_SCRIPT: ( - public_model_data_accessor.get_training_script_s3_reference - ), - HubContentDependencyType.DEFAULT_TRAINING_DATASET: ( - public_model_data_accessor.get_default_training_dataset_s3_reference - ), - HubContentDependencyType.DEMO_NOTEBOOK: ( - public_model_data_accessor.get_demo_notebook_s3_reference - ), - HubContentDependencyType.MARKDOWN: public_model_data_accessor.get_markdown_s3_reference, - } files = [] for dependency in HubContentDependencyType: - location = function_table[dependency]() + location = public_model_data_accessor.get_s3_reference(dependency) parameters = {"Bucket": location.bucket, "Prefix": location.key} response = self.s3_client.head_object(**parameters) key: str = location.key diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/fileinfo.py b/src/sagemaker/jumpstart/curated_hub/accessors/fileinfo.py index 2407560a9e..d2091215fb 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/fileinfo.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/fileinfo.py @@ -21,13 +21,13 @@ class HubContentDependencyType(str, Enum): """Enum class for HubContent dependency names""" - INFERENCE_ARTIFACT = "INFERENCE_ARTIFACT" - TRAINING_ARTIFACT = "TRAINING_ARTIFACT" - INFERNECE_SCRIPT = "INFERENCE_SCRIPT" - TRAINING_SCRIPT = "TRAINING_SCRIPT" - DEFAULT_TRAINING_DATASET = "DEFAULT_TRAINING_DATASET" - DEMO_NOTEBOOK = "DEMO_NOTEBOOK" - MARKDOWN = "MARKDOWN" + INFERENCE_ARTIFACT = "inference_artifact_s3_reference" + TRAINING_ARTIFACT = "training_artifact_s3_reference" + INFERENCE_SCRIPT = "inference_script_s3_reference" + TRAINING_SCRIPT = "training_script_s3_reference" + DEFAULT_TRAINING_DATASET = "default_training_dataset_s3_reference" + DEMO_NOTEBOOK = "demo_notebook_s3_reference" + MARKDOWN = "markdown_s3_reference" @dataclass diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py index c84c2127e9..b1438f86dd 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py @@ -14,6 +14,7 @@ from __future__ import absolute_import from typing import Dict, Any from sagemaker import model_uris, script_uris +from sagemaker.jumpstart.curated_hub.accessors.fileinfo import HubContentDependencyType from sagemaker.jumpstart.curated_hub.utils import ( get_model_framework, ) @@ -40,53 +41,92 @@ def __init__( self.model_specs = model_specs self.studio_specs = studio_specs # Necessary for SDK - Studio metadata drift - def get_bucket_name(self) -> str: + def get_s3_reference(self, dependency_type: HubContentDependencyType): + """Retrieves S3 reference given a HubContentDependencyType.""" + return getattr(self, dependency_type.value) + + @property + def inference_artifact_s3_reference(self): + """Retrieves s3 reference for model inference artifact""" + return self._get_inference_artifact_s3_reference() + + @property + def training_artifact_s3_reference(self): + """Retrieves s3 reference for model training artifact""" + return self._get_training_artifact_s3_reference() + + @property + def inference_script_s3_reference(self): + """Retrieves s3 reference for model inference script""" + return self._get_inference_script_s3_reference() + + @property + def training_script_s3_reference(self): + """Retrieves s3 reference for model training script""" + return self._get_training_script_s3_reference() + + @property + def default_training_dataset_s3_reference(self): + """Retrieves s3 reference for s3 directory containing model training datasets""" + return self._get_default_training_dataset_s3_reference() + + @property + def demo_notebook_s3_reference(self): + """Retrieves s3 reference for model demo jupyter notebook""" + return self._get_demo_notebook_s3_reference() + + @property + def markdown_s3_reference(self): + """Retrieves s3 reference for model markdown""" + return self._get_markdown_s3_reference() + + def _get_bucket_name(self) -> str: """Retrieves s3 bucket""" return self._bucket - def get_inference_artifact_s3_reference(self) -> S3ObjectLocation: + def _get_inference_artifact_s3_reference(self) -> S3ObjectLocation: """Retrieves s3 reference for model inference artifact""" return create_s3_object_reference_from_uri( self._jumpstart_artifact_s3_uri(JumpStartScriptScope.INFERENCE) ) - def get_training_artifact_s3_reference(self) -> S3ObjectLocation: + def _get_training_artifact_s3_reference(self) -> S3ObjectLocation: """Retrieves s3 reference for model training artifact""" return create_s3_object_reference_from_uri( self._jumpstart_artifact_s3_uri(JumpStartScriptScope.TRAINING) ) - def get_inference_script_s3_reference(self) -> S3ObjectLocation: + def _get_inference_script_s3_reference(self) -> S3ObjectLocation: """Retrieves s3 reference for model inference script""" return create_s3_object_reference_from_uri( self._jumpstart_script_s3_uri(JumpStartScriptScope.INFERENCE) ) - def get_training_script_s3_reference(self) -> S3ObjectLocation: + def _get_training_script_s3_reference(self) -> S3ObjectLocation: """Retrieves s3 reference for model training script""" return create_s3_object_reference_from_uri( self._jumpstart_script_s3_uri(JumpStartScriptScope.TRAINING) ) - def get_default_training_dataset_s3_reference(self) -> S3ObjectLocation: + def _get_default_training_dataset_s3_reference(self) -> S3ObjectLocation: """Retrieves s3 reference for s3 directory containing model training datasets""" - return S3ObjectLocation(self.get_bucket_name(), self._get_training_dataset_prefix()) + return S3ObjectLocation(self._get_bucket_name(), self.__get_training_dataset_prefix()) - def _get_training_dataset_prefix(self) -> str: + def __get_training_dataset_prefix(self) -> str: """Retrieves training dataset location""" return self.studio_specs["defaultDataKey"] - def get_demo_notebook_s3_reference(self) -> S3ObjectLocation: + def _get_demo_notebook_s3_reference(self) -> S3ObjectLocation: """Retrieves s3 reference for model demo jupyter notebook""" framework = get_model_framework(self.model_specs) key = f"{framework}-notebooks/{self.model_specs.model_id}-inference.ipynb" - return S3ObjectLocation(self.get_bucket_name(), key) + return S3ObjectLocation(self._get_bucket_name(), key) - def get_markdown_s3_reference(self) -> S3ObjectLocation: + def _get_markdown_s3_reference(self) -> S3ObjectLocation: """Retrieves s3 reference for model markdown""" framework = get_model_framework(self.model_specs) key = f"{framework}-metadata/{self.model_specs.model_id}.md" - return S3ObjectLocation(self.get_bucket_name(), key) + return S3ObjectLocation(self._get_bucket_name(), key) def _jumpstart_script_s3_uri(self, model_scope: str) -> str: """Retrieves JumpStart script s3 location""" diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/sync.py b/src/sagemaker/jumpstart/curated_hub/accessors/sync.py index f3be606632..8c642d222e 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/sync.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/sync.py @@ -10,7 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""This module provides a class to help copy HubContent dependencies.""" +"""This module provides a class that perfrms functionalities similar to ``S3:Copy``.""" from __future__ import absolute_import from typing import Generator, List diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/synccomparator.py b/src/sagemaker/jumpstart/curated_hub/accessors/synccomparator.py index 8a478b1458..bfbae6b617 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/synccomparator.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/synccomparator.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. """This module provides comparators for syncing s3 files.""" from __future__ import absolute_import +from datetime import timedelta from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo @@ -36,14 +37,6 @@ def determine_should_sync(self, src_file: FileInfo, dest_file: FileInfo) -> bool ) return should_sync - def total_seconds(self, td): - """ - timedelta's time_seconds() function for python 2.6 users - - :param td: The difference between two datetime objects. - """ - return (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10**6) / 10**6 - def compare_size(self, src_file: FileInfo, dest_file: FileInfo): """ :returns: True if the sizes are the same. @@ -62,7 +55,7 @@ def compare_time(self, src_file: FileInfo, dest_file: FileInfo): dest_time = dest_file.last_updated delta = dest_time - src_time # pylint: disable=R1703,R1705 - if self.total_seconds(delta) >= 0: + if timedelta.total_seconds(delta) >= 0: # Destination is newer than source. return True else: diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 9dd9bbe330..44a7c0ca9d 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -18,6 +18,7 @@ import boto3 from botocore.client import BaseClient +from packaging.version import Version from sagemaker.jumpstart.curated_hub.accessors.filegenerator import FileGenerator from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation @@ -176,53 +177,69 @@ def sync(self, model_list: List[Dict[str, str]]): # Retrieve latest version of unspecified JumpStart model versions model_version_list = [] for model in model_list: - # TODO: Uncomment and implement - # if not model["version"] or model["version"] == "*": - # model["version"] = self._find_latest_version(model_name=model.model_id) + version = model.get("version", "*") + if not version or version == "*": + model_specs = verify_model_region_and_return_specs( + model["model_id"], version, JumpStartScriptScope.INFERENCE, self.region + ) + model["version"] = model_specs.version model_version_list.append(model) # Find synced JumpStart model versions in the Hub js_models_in_hub = [] - for model in hub_models: + for hub_model in hub_models: # TODO: extract both in one pass jumpstart_model_id = next( - (tag for tag in model.search_keywords if tag.startswith("@jumpstart-model-id")), + ( + tag + for tag in hub_model["search_keywords"] + if tag.startswith("@jumpstart-model-id") + ), None, ) jumpstart_model_version = next( ( tag - for tag in model.search_keywords + for tag in hub_model["search_keywords"] if tag.startswith("@jumpstart-model-version") ), None, ) if jumpstart_model_id and jumpstart_model_version: - js_models_in_hub.append(model) + js_models_in_hub.append(hub_model) # Match inputted list of model versions with synced JumpStart model versions in the Hub models_to_sync = [] for model in model_version_list: - model_id, version = model - matched_model = next((model for model in js_models_in_hub if model.name == model_id)) + matched_model = next( + ( + hub_model + for hub_model in js_models_in_hub + if hub_model and hub_model["name"] == model["model_id"] + ), + None, + ) # Model does not exist in Hub, sync if not matched_model: models_to_sync.append(model) if matched_model: + model_version = Version(model["version"]) + hub_model_version = Version(matched_model["version"]) + # 1. Model version exists in Hub, pass - if matched_model.version == version: + if hub_model_version == model_version: pass # 2. Invalid model version exists in Hub, pass # This will only happen if something goes wrong in our metadata - if matched_model.version > version: + if hub_model_version > model_version: pass # 3. Old model version exists in Hub, update - if matched_model.version < version: + if hub_model_version < model_version: # Check minSDKVersion against current SDK version, emit log models_to_sync.append(model) @@ -269,6 +286,7 @@ def _sync_public_model_to_hub(self, model: Dict[str, str]): scope=JumpStartScriptScope.INFERENCE, sagemaker_session=self._sagemaker_session, ) + # TODO: Uncomment and implement # studio_specs = self.fetch_studio_specs(model_id=model_name, version=model_version) studio_specs = {} @@ -282,12 +300,17 @@ def _sync_public_model_to_hub(self, model: Dict[str, str]): src_files = file_generator.format(model_specs) dest_files = file_generator.format(dest_location) - files_to_copy = list(FileSync(src_files, dest_files, dest_location).call()) + files_to_copy = FileSync(src_files, dest_files, dest_location).call() if len(files_to_copy) > 0: - # Copy files with MPU + # TODO: Copy files with MPU print("hi") + # Tag model if specs say it is deprecated or training/inference vulnerable + # Update tag of HubContent ARN without version. Versioned ARNs are not + # onboarded to Tagris. + tags = [] + hub_content_document = HubContentDocument_v2(spec=model_specs) self._sagemaker_session.import_hub_content( @@ -301,4 +324,5 @@ def _sync_public_model_to_hub(self, model: Dict[str, str]): hub_content_description="", hub_content_markdown="", hub_content_search_keywords=[], + tags=tags, ) diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py index a2eabddbb8..26d780b667 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py @@ -11,14 +11,19 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import +from unittest import mock +from unittest.mock import patch import pytest from mock import Mock from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub +from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec REGION = "us-east-1" ACCOUNT_ID = "123456789123" HUB_NAME = "mock-hub-name" +MODULE_PATH = "sagemaker.jumpstart.curated_hub.curated_hub.CuratedHub" + @pytest.fixture() def sagemaker_session(): @@ -130,3 +135,141 @@ def test_create_with_bucket_name( ) sagemaker_session.create_hub.assert_called_with(**request) assert response == {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +@patch(f"{MODULE_PATH}._sync_public_model_to_hub") +@patch(f"{MODULE_PATH}.list_models") +def test_sync_kicks_off_parallel_syncs( + mock_list_models, mock_sync_public_models, mock_get_model_specs, sagemaker_session +): + mock_get_model_specs.side_effect = get_spec_from_base_spec + mock_list_models.return_value = [] + hub_name = "mock_hub_name" + model_one = {"model_id": "mock-model-one-huggingface"} + model_two = {"model_id": "mock-model-two-pytorch", "version": "1.0.2"} + mock_sync_public_models.return_value = "" + hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session) + + hub.sync([model_one, model_two]) + + mock_sync_public_models.assert_has_calls([mock.call(model_one), mock.call(model_two)]) + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +@patch(f"{MODULE_PATH}._sync_public_model_to_hub") +@patch(f"{MODULE_PATH}.list_models") +def test_sync_filters_models_that_exist_in_hub( + mock_list_models, mock_sync_public_models, mock_get_model_specs, sagemaker_session +): + mock_get_model_specs.side_effect = get_spec_from_base_spec + mock_list_models.return_value = [ + { + "name": "mock-model-two-pytorch", + "version": "1.0.2", + "search_keywords": [ + "@jumpstart-model-id:model-two-pytorch", + "@jumpstart-model-version:1.0.2", + ], + }, + {"name": "mock-model-three-nonsense", "version": "1.0.2", "search_keywords": []}, + { + "name": "mock-model-four-huggingface", + "version": "2.0.2", + "search_keywords": [ + "@jumpstart-model-id:model-four-huggingface", + "@jumpstart-model-version:2.0.2", + ], + }, + ] + hub_name = "mock_hub_name" + model_one = {"model_id": "mock-model-one-huggingface"} + model_two = {"model_id": "mock-model-two-pytorch", "version": "1.0.2"} + mock_sync_public_models.return_value = "" + hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session) + + hub.sync([model_one, model_two]) + + mock_sync_public_models.assert_called_once_with(model_one) + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +@patch(f"{MODULE_PATH}._sync_public_model_to_hub") +@patch(f"{MODULE_PATH}.list_models") +def test_sync_updates_old_models_in_hub( + mock_list_models, mock_sync_public_models, mock_get_model_specs, sagemaker_session +): + mock_get_model_specs.side_effect = get_spec_from_base_spec + mock_list_models.return_value = [ + { + "name": "mock-model-two-pytorch", + "version": "1.0.1", + "search_keywords": [ + "@jumpstart-model-id:model-two-pytorch", + "@jumpstart-model-version:1.0.2", + ], + }, + { + "name": "mock-model-three-nonsense", + "version": "1.0.2", + "search_keywords": ["tag-one", "tag-two"], + }, + { + "name": "mock-model-four-huggingface", + "version": "2.0.2", + "search_keywords": [ + "@jumpstart-model-id:model-four-huggingface", + "@jumpstart-model-version:2.0.2", + ], + }, + ] + hub_name = "mock_hub_name" + model_one = {"model_id": "mock-model-one-huggingface"} + model_two = {"model_id": "mock-model-two-pytorch", "version": "1.0.2"} + mock_sync_public_models.return_value = "" + hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session) + + hub.sync([model_one, model_two]) + + mock_sync_public_models.assert_has_calls([mock.call(model_one), mock.call(model_two)]) + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +@patch(f"{MODULE_PATH}._sync_public_model_to_hub") +@patch(f"{MODULE_PATH}.list_models") +def test_sync_passes_newer_hub_models( + mock_list_models, mock_sync_public_models, mock_get_model_specs, sagemaker_session +): + mock_get_model_specs.side_effect = get_spec_from_base_spec + mock_list_models.return_value = [ + { + "name": "mock-model-two-pytorch", + "version": "1.0.3", + "search_keywords": [ + "@jumpstart-model-id:model-two-pytorch", + "@jumpstart-model-version:1.0.2", + ], + }, + { + "name": "mock-model-three-nonsense", + "version": "1.0.2", + "search_keywords": ["tag-one", "tag-two"], + }, + { + "name": "mock-model-four-huggingface", + "version": "2.0.2", + "search_keywords": [ + "@jumpstart-model-id:model-four-huggingface", + "@jumpstart-model-version:2.0.2", + ], + }, + ] + hub_name = "mock_hub_name" + model_one = {"model_id": "mock-model-one-huggingface"} + model_two = {"model_id": "mock-model-two-pytorch", "version": "1.0.2"} + mock_sync_public_models.return_value = "" + hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session) + + hub.sync([model_one, model_two]) + + mock_sync_public_models.assert_called_once_with(model_one) diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index d87a2f957b..db1102efb2 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -873,7 +873,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") assert mocked_is_dir.call_count == 2 - assert mocked_open.call_count == 2 # TODO: ?? + mocked_open.assert_not_called() mocked_get_json_file_and_etag_from_s3.assert_has_calls( calls=[ call("models_manifest.json"), From 67d8ec8532cb45c6a446c1524a0c8c1bd2268402 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Tue, 5 Mar 2024 20:14:30 +0000 Subject: [PATCH 03/15] linting --- .../jumpstart/curated_hub/accessors/filegenerator.py | 11 ++++------- src/sagemaker/jumpstart/curated_hub/accessors/sync.py | 7 ++----- .../jumpstart/curated_hub/accessors/synccomparator.py | 6 ++++-- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py b/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py index b8a6b299e8..8761af535e 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py @@ -35,9 +35,9 @@ def __init__( @singledispatchmethod def format(self, file_input) -> List[FileInfo]: - """Dispatch method that takes in an input of either ``S3ObjectLocation`` or - ``JumpStartModelSpecs`` and is implemented in below registered functions. - """ + """Dispatch method that is implemented in below registered functions. + + Takes in an input of either ``S3ObjectLocation`` or ``JumpStartModelSpecs``.""" # pylint: disable=W0107 pass @@ -81,10 +81,7 @@ def s3_format(self, file_input: S3ObjectLocation) -> List[FileInfo]: def specs_format( self, file_input: JumpStartModelSpecs, studio_specs: Dict[str, Any] ) -> List[FileInfo]: - """ - Collects data locations from JumpStart public model specs and - converts into FileInfo. - """ + """Collects data locations from JumpStart public model specs and converts into FileInfo.""" public_model_data_accessor = PublicModelDataAccessor( region=self.region, model_specs=file_input, studio_specs=studio_specs ) diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/sync.py b/src/sagemaker/jumpstart/curated_hub/accessors/sync.py index 8c642d222e..6cc1353f5f 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/sync.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/sync.py @@ -26,9 +26,7 @@ class FileSync: """Something.""" def __init__(self, src_files: List[FileInfo], dest_files: List[FileInfo], dest_bucket: str): - """Instantiates a ``FileSync`` class. - Sorts src and dest files by name for easier - comparisons. + """Instantiates a ``FileSync`` class. Sorts src and dest files by name for comparisons. Args: src_files (List[FileInfo]): List of files to sync with destination @@ -41,8 +39,7 @@ def __init__(self, src_files: List[FileInfo], dest_files: List[FileInfo], dest_b self.dest_bucket = dest_bucket def call(self) -> Generator[FileInfo, FileInfo, FileInfo]: - """This function performs the actual comparisons. Returns a list of FileInfo - to copy. + """This function performs the actual comparisons. Returns a list of FileInfo to copy. Algorithm: Loop over sorted files in the src directory. If at the end of dest directory, diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/synccomparator.py b/src/sagemaker/jumpstart/curated_hub/accessors/synccomparator.py index bfbae6b617..f3a9a3ca7b 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/synccomparator.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/synccomparator.py @@ -38,14 +38,16 @@ def determine_should_sync(self, src_file: FileInfo, dest_file: FileInfo) -> bool return should_sync def compare_size(self, src_file: FileInfo, dest_file: FileInfo): - """ + """Compares sizes of src and dest files. + :returns: True if the sizes are the same. False otherwise. """ return src_file.size == dest_file.size def compare_time(self, src_file: FileInfo, dest_file: FileInfo): - """ + """Compares time delta between src and dest files. + :returns: True if the file does not need updating based on time of last modification and type of operation. False if the file does need updating based on the time of From 2fa0503c52de0e65a3af037933f98c95084ee938 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Tue, 5 Mar 2024 20:30:57 +0000 Subject: [PATCH 04/15] file generator class inheritance --- .../curated_hub/accessors/filegenerator.py | 46 +++++-------- .../accessors/public_model_data.py | 69 ++++++------------- .../jumpstart/curated_hub/curated_hub.py | 7 +- .../curated_hub/test_filegenerator.py | 39 +++-------- 4 files changed, 48 insertions(+), 113 deletions(-) diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py b/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py index 8761af535e..7d86e1f54c 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py @@ -12,7 +12,6 @@ # language governing permissions and limitations under the License. """This module contains important utilities related to HubContent data files.""" from __future__ import absolute_import -from functools import singledispatchmethod from typing import Any, Dict, List, Optional from botocore.client import BaseClient @@ -33,35 +32,19 @@ def __init__( self.s3_client = s3_client self.studio_specs = studio_specs - @singledispatchmethod def format(self, file_input) -> List[FileInfo]: - """Dispatch method that is implemented in below registered functions. + """Dispatch method that is implemented in below registered functions.""" + raise NotImplementedError - Takes in an input of either ``S3ObjectLocation`` or ``JumpStartModelSpecs``.""" - # pylint: disable=W0107 - pass - @format.register - def _(self, file_input: S3ObjectLocation) -> List[FileInfo]: - """Implements ``.format`` when the input is of type ``S3ObjectLocation``. +class S3PathFileGenerator(FileGenerator): + """Utility class to help format all objects in an S3 bucket.""" - Returns a list of ``FileInfo`` objects from the specified bucket location. - """ - files = self.s3_format(file_input) - return files - - @format.register - def _(self, file_input: JumpStartModelSpecs) -> List[FileInfo]: - """Implements ``.format`` when the input is of type ``JumpStartModelSpecs``. + def format(self, file_input: S3ObjectLocation) -> List[FileInfo]: + """Retrieves data from an S3 bucket and formats into FileInfo. - Returns a list of ``FileInfo`` objects from dependencies found in the public - model specs. + Returns a list of ``FileInfo`` objects from the specified bucket location. """ - files = self.specs_format(file_input, self.studio_specs) - return files - - def s3_format(self, file_input: S3ObjectLocation) -> List[FileInfo]: - """Retrieves data from a bucket and formats into FileInfo""" parameters = {"Bucket": file_input.bucket, "Prefix": file_input.key} response = self.s3_client.list_objects_v2(**parameters) contents = response.get("Contents", None) @@ -78,12 +61,17 @@ def s3_format(self, file_input: S3ObjectLocation) -> List[FileInfo]: files.append(FileInfo(key, size, last_modified)) return files - def specs_format( - self, file_input: JumpStartModelSpecs, studio_specs: Dict[str, Any] - ) -> List[FileInfo]: - """Collects data locations from JumpStart public model specs and converts into FileInfo.""" +class ModelSpecsFileGenerator(FileGenerator): + """Utility class to help format all data paths from JumpStart public model specs.""" + + def format(self, file_input: JumpStartModelSpecs) -> List[FileInfo]: + """Collects data locations from JumpStart public model specs and converts into FileInfo`. + + Returns a list of ``FileInfo`` objects from dependencies found in the public + model specs. + """ public_model_data_accessor = PublicModelDataAccessor( - region=self.region, model_specs=file_input, studio_specs=studio_specs + region=self.region, model_specs=file_input, studio_specs=self.studio_specs ) files = [] for dependency in HubContentDependencyType: diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py index b1438f86dd..a5266f1745 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py @@ -47,86 +47,57 @@ def get_s3_reference(self, dependency_type: HubContentDependencyType): @property def inference_artifact_s3_reference(self): - """Retrieves s3 reference for model inference artifact""" - return self._get_inference_artifact_s3_reference() - - @property - def training_artifact_s3_reference(self): - """Retrieves s3 reference for model training artifact""" - return self._get_training_artifact_s3_reference() - - @property - def inference_script_s3_reference(self): - """Retrieves s3 reference for model inference script""" - return self._get_inference_script_s3_reference() - - @property - def training_script_s3_reference(self): - """Retrieves s3 reference for model training script""" - return self._get_training_script_s3_reference() - - @property - def default_training_dataset_s3_reference(self): - """Retrieves s3 reference for s3 directory containing model training datasets""" - return self._get_default_training_dataset_s3_reference() - - @property - def demo_notebook_s3_reference(self): - """Retrieves s3 reference for model demo jupyter notebook""" - return self._get_demo_notebook_s3_reference() - - @property - def markdown_s3_reference(self): - """Retrieves s3 reference for model markdown""" - return self._get_markdown_s3_reference() - - def _get_bucket_name(self) -> str: - """Retrieves s3 bucket""" - return self._bucket - - def _get_inference_artifact_s3_reference(self) -> S3ObjectLocation: """Retrieves s3 reference for model inference artifact""" return create_s3_object_reference_from_uri( self._jumpstart_artifact_s3_uri(JumpStartScriptScope.INFERENCE) ) - def _get_training_artifact_s3_reference(self) -> S3ObjectLocation: + @property + def training_artifact_s3_reference(self): """Retrieves s3 reference for model training artifact""" return create_s3_object_reference_from_uri( self._jumpstart_artifact_s3_uri(JumpStartScriptScope.TRAINING) ) - def _get_inference_script_s3_reference(self) -> S3ObjectLocation: + @property + def inference_script_s3_reference(self): """Retrieves s3 reference for model inference script""" return create_s3_object_reference_from_uri( self._jumpstart_script_s3_uri(JumpStartScriptScope.INFERENCE) ) - def _get_training_script_s3_reference(self) -> S3ObjectLocation: + @property + def training_script_s3_reference(self): """Retrieves s3 reference for model training script""" return create_s3_object_reference_from_uri( self._jumpstart_script_s3_uri(JumpStartScriptScope.TRAINING) ) - def _get_default_training_dataset_s3_reference(self) -> S3ObjectLocation: + @property + def default_training_dataset_s3_reference(self): """Retrieves s3 reference for s3 directory containing model training datasets""" return S3ObjectLocation(self._get_bucket_name(), self.__get_training_dataset_prefix()) - def __get_training_dataset_prefix(self) -> str: - """Retrieves training dataset location""" - return self.studio_specs["defaultDataKey"] - - def _get_demo_notebook_s3_reference(self) -> S3ObjectLocation: + @property + def demo_notebook_s3_reference(self): """Retrieves s3 reference for model demo jupyter notebook""" framework = get_model_framework(self.model_specs) key = f"{framework}-notebooks/{self.model_specs.model_id}-inference.ipynb" return S3ObjectLocation(self._get_bucket_name(), key) - - def _get_markdown_s3_reference(self) -> S3ObjectLocation: + + @property + def markdown_s3_reference(self): """Retrieves s3 reference for model markdown""" framework = get_model_framework(self.model_specs) key = f"{framework}-metadata/{self.model_specs.model_id}.md" return S3ObjectLocation(self._get_bucket_name(), key) + def _get_bucket_name(self) -> str: + """Retrieves s3 bucket""" + return self._bucket + + def __get_training_dataset_prefix(self) -> str: + """Retrieves training dataset location""" + return self.studio_specs["defaultDataKey"] def _jumpstart_script_s3_uri(self, model_scope: str) -> str: """Retrieves JumpStart script s3 location""" diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 44a7c0ca9d..f083010ca9 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -20,7 +20,7 @@ from botocore.client import BaseClient from packaging.version import Version -from sagemaker.jumpstart.curated_hub.accessors.filegenerator import FileGenerator +from sagemaker.jumpstart.curated_hub.accessors.filegenerator import FileGenerator, ModelSpecsFileGenerator, S3PathFileGenerator from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation from sagemaker.jumpstart.curated_hub.accessors.sync import FileSync from sagemaker.jumpstart.enums import JumpStartScriptScope @@ -296,9 +296,8 @@ def _sync_public_model_to_hub(self, model: Dict[str, str]): ) # TODO: Validations? HeadBucket? - file_generator = FileGenerator(self.region, self._s3_client, studio_specs) - src_files = file_generator.format(model_specs) - dest_files = file_generator.format(dest_location) + src_files = ModelSpecsFileGenerator(self.region, self._s3_client, studio_specs).format(model_specs) + dest_files = S3PathFileGenerator(self.region, self._s3_client).format(dest_location) files_to_copy = FileSync(src_files, dest_files, dest_location).call() diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py index 65084c10b3..7fe51f9944 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py @@ -13,7 +13,7 @@ from __future__ import absolute_import import pytest from unittest.mock import Mock, patch -from sagemaker.jumpstart.curated_hub.accessors.filegenerator import FileGenerator +from sagemaker.jumpstart.curated_hub.accessors.filegenerator import FileGenerator, ModelSpecsFileGenerator, S3PathFileGenerator from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation @@ -37,30 +37,7 @@ def s3_client(): return mock_s3_client -@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_format_has_two_different_input_pathways(patched_get_model_specs, s3_client): - patched_get_model_specs.side_effect = get_spec_from_base_spec - studio_specs = {"defaultDataKey": "model_id123"} - mock_hub_bucket = S3ObjectLocation(bucket="mock-bucket", key="mock-key") - generator = FileGenerator("us-west-2", s3_client, studio_specs) - generator.format(mock_hub_bucket) - - s3_client.list_objects_v2.assert_called_once() - s3_client.head_object.assert_not_called() - patched_get_model_specs.assert_not_called() - - # Other mocks shouldn't have been called - s3_client.list_objects_v2.reset_mock() - - specs = JumpStartModelSpecs(BASE_SPEC) - generator.format(specs) - - s3_client.list_objects_v2.assert_not_called() - s3_client.head_object.assert_called() - patched_get_model_specs.assert_called() - - -def test_object_location_input_works(s3_client): +def test_s3_path_file_generator_happy_path(s3_client): s3_client.list_objects_v2.return_value = { "Contents": [ {"Key": "my-key-one", "Size": 123456789, "LastModified": "08-14-1997 00:00:00"}, @@ -69,7 +46,7 @@ def test_object_location_input_works(s3_client): } mock_hub_bucket = S3ObjectLocation(bucket="mock-bucket", key="mock-key") - generator = FileGenerator("us-west-2", s3_client) + generator = S3PathFileGenerator("us-west-2", s3_client) response = generator.format(mock_hub_bucket) s3_client.list_objects_v2.assert_called_once() @@ -80,12 +57,12 @@ def test_object_location_input_works(s3_client): @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_specs_input_works(patched_get_model_specs, s3_client): +def test_model_specs_file_generator_happy_path(patched_get_model_specs, s3_client): patched_get_model_specs.side_effect = get_spec_from_base_spec specs = JumpStartModelSpecs(BASE_SPEC) studio_specs = {"defaultDataKey": "model_id123"} - generator = FileGenerator("us-west-2", s3_client, studio_specs) + generator = ModelSpecsFileGenerator("us-west-2", s3_client, studio_specs) response = generator.format(specs) s3_client.head_object.assert_called() @@ -102,11 +79,11 @@ def test_specs_input_works(patched_get_model_specs, s3_client): ] -def test_object_location_no_objects(s3_client): +def test_s3_path_file_generator_with_no_objects(s3_client): s3_client.list_objects_v2.return_value = {"Contents": []} mock_hub_bucket = S3ObjectLocation(bucket="mock-bucket", key="mock-key") - generator = FileGenerator("us-west-2", s3_client) + generator = S3PathFileGenerator("us-west-2", s3_client) response = generator.format(mock_hub_bucket) s3_client.list_objects_v2.assert_called_once() @@ -117,7 +94,7 @@ def test_object_location_no_objects(s3_client): s3_client.list_objects_v2.return_value = {} mock_hub_bucket = S3ObjectLocation(bucket="mock-bucket", key="mock-key") - generator = FileGenerator("us-west-2", s3_client) + generator = S3PathFileGenerator("us-west-2", s3_client) response = generator.format(mock_hub_bucket) s3_client.list_objects_v2.assert_called_once() From 30c2b91293bcfdda39301f04e1043bb6f04523a4 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Tue, 5 Mar 2024 20:34:19 +0000 Subject: [PATCH 05/15] lint --- .../jumpstart/curated_hub/accessors/filegenerator.py | 1 + .../jumpstart/curated_hub/accessors/public_model_data.py | 3 ++- src/sagemaker/jumpstart/curated_hub/curated_hub.py | 9 +++++++-- .../jumpstart/curated_hub/test_filegenerator.py | 5 ++++- 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py b/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py index 7d86e1f54c..1fb50adb50 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py @@ -61,6 +61,7 @@ def format(self, file_input: S3ObjectLocation) -> List[FileInfo]: files.append(FileInfo(key, size, last_modified)) return files + class ModelSpecsFileGenerator(FileGenerator): """Utility class to help format all data paths from JumpStart public model specs.""" diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py index a5266f1745..c6f0c61262 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py @@ -84,13 +84,14 @@ def demo_notebook_s3_reference(self): framework = get_model_framework(self.model_specs) key = f"{framework}-notebooks/{self.model_specs.model_id}-inference.ipynb" return S3ObjectLocation(self._get_bucket_name(), key) - + @property def markdown_s3_reference(self): """Retrieves s3 reference for model markdown""" framework = get_model_framework(self.model_specs) key = f"{framework}-metadata/{self.model_specs.model_id}.md" return S3ObjectLocation(self._get_bucket_name(), key) + def _get_bucket_name(self) -> str: """Retrieves s3 bucket""" return self._bucket diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index f083010ca9..f5a64bbf06 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -20,7 +20,10 @@ from botocore.client import BaseClient from packaging.version import Version -from sagemaker.jumpstart.curated_hub.accessors.filegenerator import FileGenerator, ModelSpecsFileGenerator, S3PathFileGenerator +from sagemaker.jumpstart.curated_hub.accessors.filegenerator import ( + ModelSpecsFileGenerator, + S3PathFileGenerator, +) from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation from sagemaker.jumpstart.curated_hub.accessors.sync import FileSync from sagemaker.jumpstart.enums import JumpStartScriptScope @@ -296,7 +299,9 @@ def _sync_public_model_to_hub(self, model: Dict[str, str]): ) # TODO: Validations? HeadBucket? - src_files = ModelSpecsFileGenerator(self.region, self._s3_client, studio_specs).format(model_specs) + src_files = ModelSpecsFileGenerator(self.region, self._s3_client, studio_specs).format( + model_specs + ) dest_files = S3PathFileGenerator(self.region, self._s3_client).format(dest_location) files_to_copy = FileSync(src_files, dest_files, dest_location).call() diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py index 7fe51f9944..2e9aa1a73f 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py @@ -13,7 +13,10 @@ from __future__ import absolute_import import pytest from unittest.mock import Mock, patch -from sagemaker.jumpstart.curated_hub.accessors.filegenerator import FileGenerator, ModelSpecsFileGenerator, S3PathFileGenerator +from sagemaker.jumpstart.curated_hub.accessors.filegenerator import ( + ModelSpecsFileGenerator, + S3PathFileGenerator, +) from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation From ef57f14f3f4f8615489e2a1b8f4b129490ddb857 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Wed, 6 Mar 2024 19:09:53 +0000 Subject: [PATCH 06/15] multipart copy and algorithm updates --- .../curated_hub/accessors/filegenerator.py | 41 +++-- .../curated_hub/accessors/fileinfo.py | 12 +- .../curated_hub/accessors/multipartcopy.py | 128 +++++++++++++++ .../curated_hub/accessors/objectlocation.py | 9 +- .../jumpstart/curated_hub/accessors/sync.py | 80 ++++++++-- .../curated_hub/accessors/synccomparator.py | 7 +- .../jumpstart/curated_hub/curated_hub.py | 57 ++++--- src/sagemaker/jumpstart/curated_hub/types.py | 8 +- .../jumpstart/curated_hub/test_curated_hub.py | 150 ++++++++++-------- .../curated_hub/test_filegenerator.py | 50 ++++-- .../curated_hub/test_synccomparator.py | 12 +- 11 files changed, 410 insertions(+), 144 deletions(-) create mode 100644 src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py b/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py index 1fb50adb50..828158377e 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py @@ -14,6 +14,7 @@ from __future__ import absolute_import from typing import Any, Dict, List, Optional +from datetime import datetime from botocore.client import BaseClient from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo, HubContentDependencyType @@ -58,7 +59,7 @@ def format(self, file_input: S3ObjectLocation) -> List[FileInfo]: key: str = s3_obj.get("Key") size: bytes = s3_obj.get("Size", None) last_modified: str = s3_obj.get("LastModified", None) - files.append(FileInfo(key, size, last_modified)) + files.append(FileInfo(file_input.bucket, key, size, last_modified)) return files @@ -76,12 +77,34 @@ def format(self, file_input: JumpStartModelSpecs) -> List[FileInfo]: ) files = [] for dependency in HubContentDependencyType: - location = public_model_data_accessor.get_s3_reference(dependency) - parameters = {"Bucket": location.bucket, "Prefix": location.key} - response = self.s3_client.head_object(**parameters) - key: str = location.key - size: bytes = response.get("ContentLength", None) - last_updated: str = response.get("LastModified", None) - dependency_type: HubContentDependencyType = dependency - files.append(FileInfo(key, size, last_updated, dependency_type)) + location: S3ObjectLocation = public_model_data_accessor.get_s3_reference(dependency) + + # Prefix + if location.key[-1] == "/": + parameters = {"Bucket": location.bucket, "Prefix": location.key} + response = self.s3_client.list_objects_v2(**parameters) + contents = response.get("Contents", None) + for s3_obj in contents: + key: str = s3_obj.get("Key") + size: bytes = s3_obj.get("Size", None) + last_modified: datetime = s3_obj.get("LastModified", None) + dependency_type: HubContentDependencyType = dependency + files.append( + FileInfo( + location.bucket, + key, + size, + last_modified, + dependency_type, + ) + ) + else: + parameters = {"Bucket": location.bucket, "Key": location.key} + response = self.s3_client.head_object(**parameters) + size: bytes = response.get("ContentLength", None) + last_updated: datetime = response.get("LastModified", None) + dependency_type: HubContentDependencyType = dependency + files.append( + FileInfo(location.bucket, location.key, size, last_updated, dependency_type) + ) return files diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/fileinfo.py b/src/sagemaker/jumpstart/curated_hub/accessors/fileinfo.py index d2091215fb..ea418aebad 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/fileinfo.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/fileinfo.py @@ -16,6 +16,9 @@ from enum import Enum from dataclasses import dataclass from typing import Optional +from datetime import datetime + +from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation class HubContentDependencyType(str, Enum): @@ -34,14 +37,17 @@ class HubContentDependencyType(str, Enum): class FileInfo: """Data class for additional S3 file info.""" + location: S3ObjectLocation + def __init__( self, - name: str, + bucket: str, + key: str, size: Optional[bytes], - last_updated: Optional[str], + last_updated: Optional[datetime], dependecy_type: Optional[HubContentDependencyType] = None, ): - self.name = name + self.location = S3ObjectLocation(bucket, key) self.size = size self.last_updated = last_updated self.dependecy_type = dependecy_type diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py b/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py new file mode 100644 index 0000000000..9ea7971f41 --- /dev/null +++ b/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py @@ -0,0 +1,128 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module provides a class that perfrms functionalities similar to ``S3:Copy``.""" +from __future__ import absolute_import +from typing import List + +import boto3 +import botocore +import tqdm + +from sagemaker.jumpstart.constants import JUMPSTART_LOGGER +from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo +from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation + +s3transfer = boto3.s3.transfer + + +# pylint: disable=R1705,R1710 +def human_readable_size(value): + """Convert a size in bytes into a human readable format. + + For example:: + + >>> human_readable_size(1) + '1 Byte' + >>> human_readable_size(10) + '10 Bytes' + >>> human_readable_size(1024) + '1.0 KiB' + >>> human_readable_size(1024 * 1024) + '1.0 MiB' + + :param value: The size in bytes. + :return: The size in a human readable format based on base-2 units. + + """ + base = 1024 + bytes_int = float(value) + + if bytes_int == 1: + return "1 Byte" + elif bytes_int < base: + return "%d Bytes" % bytes_int + + for i, suffix in enumerate(("KiB", "MiB", "GiB", "TiB", "PiB", "EiB")): + unit = base ** (i + 2) + if round((bytes_int / unit) * base) < base: + return "%.1f %s" % ((base * bytes_int / unit), suffix) + + +class MultiPartCopyHandler(object): + """Multi Part Copy Handler class.""" + + WORKERS = 20 + MULTIPART_CONFIG = 8 * (1024**2) + + def __init__( + self, + region: str, + files: List[FileInfo], + dest_location: S3ObjectLocation, + ): + """Something.""" + self.region = region + self.files = files + self.dest_location = dest_location + + config = botocore.config.Config(max_pool_connections=self.WORKERS) + self.s3_client = boto3.client("s3", region_name=self.region, config=config) + transfer_config = s3transfer.TransferConfig( + multipart_threshold=self.MULTIPART_CONFIG, + multipart_chunksize=self.MULTIPART_CONFIG, + max_bandwidth=True, + use_threads=True, + max_concurrency=self.WORKERS, + ) + self.transfer_manager = s3transfer.create_transfer_manager( + client=self.s3_client, config=transfer_config + ) + + def _copy_file(self, file: FileInfo, update_fn): + """Something.""" + copy_source = {"Bucket": file.location.bucket, "Key": file.location.key} + result = self.transfer_manager.copy( + bucket=self.dest_location.bucket, + key=f"{self.dest_location.key}/{file.location.key}", + copy_source=copy_source, + subscribers=[ + s3transfer.ProgressCallbackInvoker(update_fn), + ], + ) + result.result() + + def call(self): + """Something.""" + total_size = sum([file.size for file in self.files]) + JUMPSTART_LOGGER.warning( + "Copying %s files (%s) into %s/%s", + len(self.files), + human_readable_size(total_size), + self.dest_location.bucket, + self.dest_location.key, + ) + + progress = tqdm.tqdm( + desc="JumpStart Sync", + total=total_size, + unit="B", + unit_scale=1, + position=0, + bar_format="{desc:<10}{percentage:3.0f}%|{bar:10}{r_bar}", + ) + + for file in self.files: + self._copy_file(file, progress.update) + + self.transfer_manager.shutdown() + progress.close() diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/objectlocation.py b/src/sagemaker/jumpstart/curated_hub/accessors/objectlocation.py index b0512bed2e..bb72152ee1 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/objectlocation.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/objectlocation.py @@ -15,6 +15,8 @@ from dataclasses import dataclass from typing import Dict +from sagemaker.s3_utils import parse_s3_url + @dataclass class S3ObjectLocation: @@ -37,10 +39,9 @@ def get_uri(self) -> str: def create_s3_object_reference_from_uri(s3_uri: str) -> S3ObjectLocation: """Utiity to help generate an S3 object reference""" - uri_with_s3_prefix_removed = s3_uri.replace("s3://", "", 1) - uri_split = uri_with_s3_prefix_removed.split("/") + bucket, key = parse_s3_url(s3_uri) return S3ObjectLocation( - bucket=uri_split[0], - key="/".join(uri_split[1:]) if len(uri_split) > 1 else "", + bucket=bucket, + key=key, ) diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/sync.py b/src/sagemaker/jumpstart/curated_hub/accessors/sync.py index 6cc1353f5f..b724f5daf0 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/sync.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/sync.py @@ -12,9 +12,11 @@ # language governing permissions and limitations under the License. """This module provides a class that perfrms functionalities similar to ``S3:Copy``.""" from __future__ import absolute_import +from dataclasses import dataclass from typing import Generator, List from botocore.compat import six +from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation from sagemaker.jumpstart.curated_hub.accessors.synccomparator import SizeAndLastUpdatedComparator from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo @@ -22,10 +24,26 @@ advance_iterator = six.advance_iterator +@dataclass +class FileSyncResult: + """File Sync Result class""" + + files: List[FileInfo] + destination: S3ObjectLocation + + def __init__( + self, files_to_copy: Generator[FileInfo, FileInfo, FileInfo], destination: S3ObjectLocation + ): + self.files = list(files_to_copy) + self.destination = destination + + class FileSync: - """Something.""" + """FileSync class.""" - def __init__(self, src_files: List[FileInfo], dest_files: List[FileInfo], dest_bucket: str): + def __init__( + self, src_files: List[FileInfo], dest_files: List[FileInfo], destination: S3ObjectLocation + ): """Instantiates a ``FileSync`` class. Sorts src and dest files by name for comparisons. Args: @@ -34,11 +52,19 @@ def __init__(self, src_files: List[FileInfo], dest_files: List[FileInfo], dest_b dest_bucket (str): Destination bucket name for copied data """ self.comparator = SizeAndLastUpdatedComparator() - self.src_files: List[FileInfo] = sorted(src_files, lambda x: x.name) - self.dest_files: List[FileInfo] = sorted(dest_files, lambda x: x.name) - self.dest_bucket = dest_bucket + self.src_files: List[FileInfo] = sorted(src_files, key=lambda x: x.location.key) + self.dest_files: List[FileInfo] = sorted(dest_files, key=lambda x: x.location.key) + self.destination = destination - def call(self) -> Generator[FileInfo, FileInfo, FileInfo]: + def call(self) -> FileSyncResult: + """Determines which files to copy based on the comparator. + + Returns a ``FileSyncResult`` object. + """ + files_to_copy = self._determine_files_to_copy() + return FileSyncResult(files_to_copy, self.destination) + + def _determine_files_to_copy(self) -> Generator[FileInfo, FileInfo, FileInfo]: """This function performs the actual comparisons. Returns a list of FileInfo to copy. Algorithm: @@ -51,21 +77,47 @@ def call(self) -> Generator[FileInfo, FileInfo, FileInfo]: """ # :var dest_done: True if there are no files from the dest left. dest_done = False + iterator = iter(self.dest_files) + # Begin by advancing the iterator to the first file + try: + dest_file: FileInfo = advance_iterator(iterator) + except StopIteration: + dest_done = True + for src_file in self.src_files: + # End of dest, yield remaining src_files if dest_done: yield src_file continue - while not dest_done: + # We've identified two files that have the same name, further compare + if self._is_same_file_name(src_file.location.key, dest_file.location.key): + should_sync = self.comparator.determine_should_sync(src_file, dest_file) + + if should_sync: + yield src_file + + # Increment dest_files and src_files try: - dest_file: FileInfo = advance_iterator(self.dest_files) + dest_file: FileInfo = advance_iterator(iterator) except StopIteration: dest_done = True - break + continue + + # Past the src file alphabetically in dest file list. Take the src file and continue + if self._is_alphabetically_larger_file_name( + src_file.location.key, dest_file.location.key + ): + yield src_file + continue - if src_file.name == dest_file.name: - should_sync = self.comparator.determine_should_sync(src_file, dest_file) + def _is_same_file_name(self, src_filename: str, dest_filename: str) -> bool: + """Compares two file names and determiens if they are the same. + + Destination files might add a prefix. + """ + return dest_filename.endswith(src_filename) - if should_sync: - yield src_file - break + def _is_alphabetically_larger_file_name(self, src_filename: str, dest_filename: str) -> bool: + """Determines if one filename is alphabetically earlier than another.""" + return src_filename > dest_filename diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/synccomparator.py b/src/sagemaker/jumpstart/curated_hub/accessors/synccomparator.py index f3a9a3ca7b..6103d56d72 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/synccomparator.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/synccomparator.py @@ -13,6 +13,7 @@ """This module provides comparators for syncing s3 files.""" from __future__ import absolute_import from datetime import timedelta +from sagemaker.jumpstart.constants import JUMPSTART_LOGGER from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo @@ -26,10 +27,10 @@ def determine_should_sync(self, src_file: FileInfo, dest_file: FileInfo) -> bool same_last_modified_time = self.compare_time(src_file, dest_file) should_sync = (not same_size) or (not same_last_modified_time) if should_sync: - print( + JUMPSTART_LOGGER.warning( "syncing: %s -> %s, size: %s -> %s, modified time: %s -> %s", - src_file.name, - src_file.name, + src_file.location.key, + src_file.location.key, src_file.size, dest_file.size, src_file.last_updated, diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index f5a64bbf06..29c3f05289 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -13,27 +13,31 @@ """This module provides the JumpStart Curated Hub class.""" from __future__ import absolute_import from concurrent import futures +import json import traceback from typing import Optional, Dict, List, Any import boto3 +from botocore import exceptions from botocore.client import BaseClient from packaging.version import Version +from sagemaker.jumpstart import utils from sagemaker.jumpstart.curated_hub.accessors.filegenerator import ( ModelSpecsFileGenerator, S3PathFileGenerator, ) +from sagemaker.jumpstart.curated_hub.accessors.multipartcopy import MultiPartCopyHandler from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation from sagemaker.jumpstart.curated_hub.accessors.sync import FileSync from sagemaker.jumpstart.enums import JumpStartScriptScope -from sagemaker.jumpstart.utils import verify_model_region_and_return_specs from sagemaker.session import Session -from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER from sagemaker.jumpstart.types import ( DescribeHubResponse, DescribeHubContentsResponse, HubContentType, + JumpStartModelSpecs, ) from sagemaker.jumpstart.curated_hub.utils import ( create_hub_bucket_if_it_does_not_exist, @@ -73,11 +77,12 @@ def _fetch_hub_bucket_name(self) -> str: """Retrieves hub bucket name from Hub config if exists""" try: hub_response = self._sagemaker_session.describe_hub(hub_name=self.hub_name) - hub_bucket_prefix = hub_response["S3StorageConfig"]["S3OutputPath"] - return hub_bucket_prefix # TODO: Strip s3:// prefix - except ValueError: + hub_bucket_prefix = hub_response["S3StorageConfig"].get("S3OutputPath", None) + if hub_bucket_prefix: + return hub_bucket_prefix.replace("s3://", "") + return generate_default_hub_bucket_name(self._sagemaker_session) + except exceptions.ClientError: hub_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session) - print(f"Hub bucket name is: {hub_bucket_name}") # TODO: Better messaging return hub_bucket_name def create( @@ -98,7 +103,7 @@ def create( hub_description=description, hub_display_name=display_name, hub_search_keywords=search_keywords, - hub_bucket_name=bucket_name, + s3_storage_config={"S3OutputPath": f"s3://{bucket_name}"}, tags=tags, ) @@ -182,7 +187,7 @@ def sync(self, model_list: List[Dict[str, str]]): for model in model_list: version = model.get("version", "*") if not version or version == "*": - model_specs = verify_model_region_and_return_specs( + model_specs = utils.verify_model_region_and_return_specs( model["model_id"], version, JumpStartScriptScope.INFERENCE, self.region ) model["version"] = model_specs.version @@ -190,7 +195,7 @@ def sync(self, model_list: List[Dict[str, str]]): # Find synced JumpStart model versions in the Hub js_models_in_hub = [] - for hub_model in hub_models: + for hub_model in hub_models["HubContentSummaries"]: # TODO: extract both in one pass jumpstart_model_id = next( ( @@ -279,43 +284,42 @@ def sync(self, model_list: List[Dict[str, str]]): def _sync_public_model_to_hub(self, model: Dict[str, str]): """Syncs a public JumpStart model version to the Hub. Runs in parallel.""" - model_name = model["name"] + model_name = model["model_id"] model_version = model["version"] - model_specs = verify_model_region_and_return_specs( + model_specs = utils.verify_model_region_and_return_specs( model_id=model_name, version=model_version, region=self.region, scope=JumpStartScriptScope.INFERENCE, sagemaker_session=self._sagemaker_session, ) - - # TODO: Uncomment and implement - # studio_specs = self.fetch_studio_specs(model_id=model_name, version=model_version) - studio_specs = {} + studio_specs = self._fetch_studio_specs(model_specs=model_specs) dest_location = S3ObjectLocation( bucket=self.hub_bucket_name, key=f"{model_name}/{model_version}" ) - # TODO: Validations? HeadBucket? src_files = ModelSpecsFileGenerator(self.region, self._s3_client, studio_specs).format( model_specs ) dest_files = S3PathFileGenerator(self.region, self._s3_client).format(dest_location) - files_to_copy = FileSync(src_files, dest_files, dest_location).call() + sync_result = FileSync(src_files, dest_files, dest_location).call() - if len(files_to_copy) > 0: - # TODO: Copy files with MPU - print("hi") + if len(sync_result.files) > 0: + MultiPartCopyHandler( + region=self.region, files=sync_result.files, dest_location=sync_result.destination + ).call() + else: + JUMPSTART_LOGGER.warning("[%s/%s] Nothing to copy", model_name, model_version) # Tag model if specs say it is deprecated or training/inference vulnerable # Update tag of HubContent ARN without version. Versioned ARNs are not # onboarded to Tagris. tags = [] - hub_content_document = HubContentDocument_v2(spec=model_specs) + hub_content_document = HubContentDocument_v2(spec=model_specs).__str__() self._sagemaker_session.import_hub_content( document_schema_version=HubContentDocument_v2.SCHEMA_VERSION, @@ -330,3 +334,14 @@ def _sync_public_model_to_hub(self, model: Dict[str, str]): hub_content_search_keywords=[], tags=tags, ) + + def _fetch_studio_specs(self, model_specs: JumpStartModelSpecs) -> Dict[str, Any]: + """Fetches StudioSpecs given a models' SDK Specs.""" + model_id = model_specs.model_id + model_version = model_specs.version + + key = f"studio_models/{model_id}/studio_specs_v{model_version}.json" + response = self._s3_client.get_object( + Bucket=utils.get_jumpstart_content_bucket(self.region), Key=key + ) + return json.loads(response["Body"].read().decode("utf-8")) diff --git a/src/sagemaker/jumpstart/curated_hub/types.py b/src/sagemaker/jumpstart/curated_hub/types.py index 4de41d9df7..b788bc1d7c 100644 --- a/src/sagemaker/jumpstart/curated_hub/types.py +++ b/src/sagemaker/jumpstart/curated_hub/types.py @@ -14,7 +14,7 @@ from __future__ import absolute_import from typing import Dict, Any -from sagemaker.jumpstart.types import JumpStartDataHolderType +from sagemaker.jumpstart.types import JumpStartDataHolderType, JumpStartModelSpecs class HubContentDocument_v2(JumpStartDataHolderType): @@ -28,13 +28,13 @@ def __init__(self, spec: Dict[str, Any]): Args: spec (Dict[str, Any]): Dictionary representation of spec. """ - self.from_json(spec) + self.from_specs(spec) - def from_json(self, json_obj: Dict[str, Any]) -> None: + def from_specs(self, model_specs: JumpStartModelSpecs) -> None: """Sets fields in object based on json. Args: json_obj (Dict[str, Any]): Dictionary representatino of spec. """ # TODO: Implement - self.Url: str = json_obj["url"] + self.Url: str = model_specs.url diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py index 26d780b667..d9c4a37e94 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py @@ -73,14 +73,16 @@ def test_create_with_no_bucket_name( ): create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} sagemaker_session.create_hub = Mock(return_value=create_hub) - sagemaker_session.describe_hub.return_value = {"S3StorageConfig": {"S3OutputPath": None}} + sagemaker_session.describe_hub.return_value = { + "S3StorageConfig": {"S3OutputPath": hub_bucket_name} + } hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session) request = { "hub_name": hub_name, "hub_description": hub_description, - "hub_bucket_name": "sagemaker-hubs-us-east-1-123456789123", "hub_display_name": hub_display_name, "hub_search_keywords": hub_search_keywords, + "s3_storage_config": {"S3OutputPath": "s3://sagemaker-hubs-us-east-1-123456789123"}, "tags": tags, } response = hub.create( @@ -118,13 +120,15 @@ def test_create_with_bucket_name( ): create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} sagemaker_session.create_hub = Mock(return_value=create_hub) - hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session) + hub = CuratedHub( + hub_name=hub_name, sagemaker_session=sagemaker_session, bucket_name=hub_bucket_name + ) request = { "hub_name": hub_name, "hub_description": hub_description, - "hub_bucket_name": hub_bucket_name, "hub_display_name": hub_display_name, "hub_search_keywords": hub_search_keywords, + "s3_storage_config": {"S3OutputPath": "s3://mock-bucket-123"}, "tags": tags, } response = hub.create( @@ -144,7 +148,7 @@ def test_sync_kicks_off_parallel_syncs( mock_list_models, mock_sync_public_models, mock_get_model_specs, sagemaker_session ): mock_get_model_specs.side_effect = get_spec_from_base_spec - mock_list_models.return_value = [] + mock_list_models.return_value = {"HubContentSummaries": []} hub_name = "mock_hub_name" model_one = {"model_id": "mock-model-one-huggingface"} model_two = {"model_id": "mock-model-two-pytorch", "version": "1.0.2"} @@ -163,25 +167,27 @@ def test_sync_filters_models_that_exist_in_hub( mock_list_models, mock_sync_public_models, mock_get_model_specs, sagemaker_session ): mock_get_model_specs.side_effect = get_spec_from_base_spec - mock_list_models.return_value = [ - { - "name": "mock-model-two-pytorch", - "version": "1.0.2", - "search_keywords": [ - "@jumpstart-model-id:model-two-pytorch", - "@jumpstart-model-version:1.0.2", - ], - }, - {"name": "mock-model-three-nonsense", "version": "1.0.2", "search_keywords": []}, - { - "name": "mock-model-four-huggingface", - "version": "2.0.2", - "search_keywords": [ - "@jumpstart-model-id:model-four-huggingface", - "@jumpstart-model-version:2.0.2", - ], - }, - ] + mock_list_models.return_value = { + "HubContentSummaries": [ + { + "name": "mock-model-two-pytorch", + "version": "1.0.2", + "search_keywords": [ + "@jumpstart-model-id:model-two-pytorch", + "@jumpstart-model-version:1.0.2", + ], + }, + {"name": "mock-model-three-nonsense", "version": "1.0.2", "search_keywords": []}, + { + "name": "mock-model-four-huggingface", + "version": "2.0.2", + "search_keywords": [ + "@jumpstart-model-id:model-four-huggingface", + "@jumpstart-model-version:2.0.2", + ], + }, + ] + } hub_name = "mock_hub_name" model_one = {"model_id": "mock-model-one-huggingface"} model_two = {"model_id": "mock-model-two-pytorch", "version": "1.0.2"} @@ -200,29 +206,31 @@ def test_sync_updates_old_models_in_hub( mock_list_models, mock_sync_public_models, mock_get_model_specs, sagemaker_session ): mock_get_model_specs.side_effect = get_spec_from_base_spec - mock_list_models.return_value = [ - { - "name": "mock-model-two-pytorch", - "version": "1.0.1", - "search_keywords": [ - "@jumpstart-model-id:model-two-pytorch", - "@jumpstart-model-version:1.0.2", - ], - }, - { - "name": "mock-model-three-nonsense", - "version": "1.0.2", - "search_keywords": ["tag-one", "tag-two"], - }, - { - "name": "mock-model-four-huggingface", - "version": "2.0.2", - "search_keywords": [ - "@jumpstart-model-id:model-four-huggingface", - "@jumpstart-model-version:2.0.2", - ], - }, - ] + mock_list_models.return_value = { + "HubContentSummaries": [ + { + "name": "mock-model-two-pytorch", + "version": "1.0.1", + "search_keywords": [ + "@jumpstart-model-id:model-two-pytorch", + "@jumpstart-model-version:1.0.2", + ], + }, + { + "name": "mock-model-three-nonsense", + "version": "1.0.2", + "search_keywords": ["tag-one", "tag-two"], + }, + { + "name": "mock-model-four-huggingface", + "version": "2.0.2", + "search_keywords": [ + "@jumpstart-model-id:model-four-huggingface", + "@jumpstart-model-version:2.0.2", + ], + }, + ] + } hub_name = "mock_hub_name" model_one = {"model_id": "mock-model-one-huggingface"} model_two = {"model_id": "mock-model-two-pytorch", "version": "1.0.2"} @@ -241,29 +249,31 @@ def test_sync_passes_newer_hub_models( mock_list_models, mock_sync_public_models, mock_get_model_specs, sagemaker_session ): mock_get_model_specs.side_effect = get_spec_from_base_spec - mock_list_models.return_value = [ - { - "name": "mock-model-two-pytorch", - "version": "1.0.3", - "search_keywords": [ - "@jumpstart-model-id:model-two-pytorch", - "@jumpstart-model-version:1.0.2", - ], - }, - { - "name": "mock-model-three-nonsense", - "version": "1.0.2", - "search_keywords": ["tag-one", "tag-two"], - }, - { - "name": "mock-model-four-huggingface", - "version": "2.0.2", - "search_keywords": [ - "@jumpstart-model-id:model-four-huggingface", - "@jumpstart-model-version:2.0.2", - ], - }, - ] + mock_list_models.return_value = { + "HubContentSummaries": [ + { + "name": "mock-model-two-pytorch", + "version": "1.0.3", + "search_keywords": [ + "@jumpstart-model-id:model-two-pytorch", + "@jumpstart-model-version:1.0.2", + ], + }, + { + "name": "mock-model-three-nonsense", + "version": "1.0.2", + "search_keywords": ["tag-one", "tag-two"], + }, + { + "name": "mock-model-four-huggingface", + "version": "2.0.2", + "search_keywords": [ + "@jumpstart-model-id:model-four-huggingface", + "@jumpstart-model-version:2.0.2", + ], + }, + ] + } hub_name = "mock_hub_name" model_one = {"model_id": "mock-model-one-huggingface"} model_two = {"model_id": "mock-model-two-pytorch", "version": "1.0.2"} diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py index 2e9aa1a73f..0b54d694e2 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py @@ -54,8 +54,8 @@ def test_s3_path_file_generator_happy_path(s3_client): s3_client.list_objects_v2.assert_called_once() assert response == [ - FileInfo("my-key-one", 123456789, "08-14-1997 00:00:00"), - FileInfo("my-key-one", 10101010, "08-14-1997 00:00:00"), + FileInfo("mock-bucket", "my-key-one", 123456789, "08-14-1997 00:00:00"), + FileInfo("mock-bucket", "my-key-one", 10101010, "08-14-1997 00:00:00"), ] @@ -70,15 +70,45 @@ def test_model_specs_file_generator_happy_path(patched_get_model_specs, s3_clien s3_client.head_object.assert_called() patched_get_model_specs.assert_called() - # TODO: Figure out why object attrs aren't being compared + assert response == [ - FileInfo("my-key-one", 123456789, "08-14-1997 00:00:00"), - FileInfo("my-key-two", 123456789, "08-14-1997 00:00:00"), - FileInfo("my-key-three", 123456789, "08-14-1997 00:00:00"), - FileInfo("my-key-four", 123456789, "08-14-1997 00:00:00"), - FileInfo("my-key-five", 123456789, "08-14-1997 00:00:00"), - FileInfo("my-key-six", 123456789, "08-14-1997 00:00:00"), - FileInfo("my-key-seven", 123456789, "08-14-1997 00:00:00"), + FileInfo( + "jumpstart-cache-prod-us-west-2", + "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", + 123456789, + "08-14-1997 00:00:00", + ), + FileInfo( + "jumpstart-cache-prod-us-west-2", + "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", + 123456789, + "08-14-1997 00:00:00", + ), + FileInfo( + "jumpstart-cache-prod-us-west-2", + "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", + 123456789, + "08-14-1997 00:00:00", + ), + FileInfo( + "jumpstart-cache-prod-us-west-2", + "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", + 123456789, + "08-14-1997 00:00:00", + ), + FileInfo("jumpstart-cache-prod-us-west-2", "model_id123", 123456789, "08-14-1997 00:00:00"), + FileInfo( + "jumpstart-cache-prod-us-west-2", + "pytorch-notebooks/pytorch-ic-mobilenet-v2-inference.ipynb", + 123456789, + "08-14-1997 00:00:00", + ), + FileInfo( + "jumpstart-cache-prod-us-west-2", + "pytorch-metadata/pytorch-ic-mobilenet-v2.md", + 123456789, + "08-14-1997 00:00:00", + ), ] diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_synccomparator.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_synccomparator.py index 74f7aafc9a..b5e1654f05 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_synccomparator.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_synccomparator.py @@ -22,20 +22,20 @@ class SizeAndLastUpdateComparatorTest(unittest.TestCase): comparator = SizeAndLastUpdatedComparator() def test_identical_files_returns_false(self): - file_one = FileInfo("my-file-one", 123456789, datetime.today()) - file_two = FileInfo("my-file-two", 123456789, datetime.today()) + file_one = FileInfo("bucket", "my-file-one", 123456789, datetime.today()) + file_two = FileInfo("bucket", "my-file-two", 123456789, datetime.today()) assert self.comparator.determine_should_sync(file_one, file_two) is False def test_different_file_sizes_returns_true(self): - file_one = FileInfo("my-file-one", 123456789, datetime.today()) - file_two = FileInfo("my-file-two", 10101010, datetime.today()) + file_one = FileInfo("bucket", "my-file-one", 123456789, datetime.today()) + file_two = FileInfo("bucket", "my-file-two", 10101010, datetime.today()) assert self.comparator.determine_should_sync(file_one, file_two) is True def test_different_file_dates_returns_true(self): # change ordering of datetime.today() calls to trigger update - file_two = FileInfo("my-file-two", 123456789, datetime.today()) - file_one = FileInfo("my-file-one", 123456789, datetime.today()) + file_two = FileInfo("bucket", "my-file-two", 123456789, datetime.today()) + file_one = FileInfo("bucket", "my-file-one", 123456789, datetime.today()) assert self.comparator.determine_should_sync(file_one, file_two) is True From c44acd293274366fa1e0b1114b3d9169d4b7637e Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Wed, 6 Mar 2024 22:38:58 +0000 Subject: [PATCH 07/15] modularize sync --- .../curated_hub/accessors/filegenerator.py | 151 ++++++++---------- .../curated_hub/accessors/multipartcopy.py | 4 +- .../accessors/public_model_data.py | 7 +- .../curated_hub/accessors/synccomparator.py | 22 ++- .../accessors/{sync.py => synctask.py} | 43 +++-- .../jumpstart/curated_hub/curated_hub.py | 123 ++++++++------ src/sagemaker/jumpstart/curated_hub/utils.py | 6 - src/sagemaker/jumpstart/types.py | 4 + src/sagemaker/jumpstart/utils.py | 5 + .../curated_hub/test_filegenerator.py | 16 +- 10 files changed, 208 insertions(+), 173 deletions(-) rename src/sagemaker/jumpstart/curated_hub/accessors/{sync.py => synctask.py} (76%) diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py b/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py index 828158377e..b4612f7b08 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. """This module contains important utilities related to HubContent data files.""" from __future__ import absolute_import -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List from datetime import datetime from botocore.client import BaseClient @@ -23,88 +23,73 @@ from sagemaker.jumpstart.types import JumpStartModelSpecs -class FileGenerator: - """Utility class to help format HubContent data files.""" - - def __init__( - self, region: str, s3_client: BaseClient, studio_specs: Optional[Dict[str, Any]] = None - ): - self.region = region - self.s3_client = s3_client - self.studio_specs = studio_specs - - def format(self, file_input) -> List[FileInfo]: - """Dispatch method that is implemented in below registered functions.""" - raise NotImplementedError - - -class S3PathFileGenerator(FileGenerator): - """Utility class to help format all objects in an S3 bucket.""" - - def format(self, file_input: S3ObjectLocation) -> List[FileInfo]: - """Retrieves data from an S3 bucket and formats into FileInfo. - - Returns a list of ``FileInfo`` objects from the specified bucket location. - """ - parameters = {"Bucket": file_input.bucket, "Prefix": file_input.key} - response = self.s3_client.list_objects_v2(**parameters) - contents = response.get("Contents", None) - - if not contents: - print("Nothing to download") - return [] - - files = [] - for s3_obj in contents: - key: str = s3_obj.get("Key") - size: bytes = s3_obj.get("Size", None) - last_modified: str = s3_obj.get("LastModified", None) - files.append(FileInfo(file_input.bucket, key, size, last_modified)) - return files - - -class ModelSpecsFileGenerator(FileGenerator): - """Utility class to help format all data paths from JumpStart public model specs.""" - - def format(self, file_input: JumpStartModelSpecs) -> List[FileInfo]: - """Collects data locations from JumpStart public model specs and converts into FileInfo`. - - Returns a list of ``FileInfo`` objects from dependencies found in the public - model specs. - """ - public_model_data_accessor = PublicModelDataAccessor( - region=self.region, model_specs=file_input, studio_specs=self.studio_specs - ) - files = [] - for dependency in HubContentDependencyType: - location: S3ObjectLocation = public_model_data_accessor.get_s3_reference(dependency) - - # Prefix - if location.key[-1] == "/": - parameters = {"Bucket": location.bucket, "Prefix": location.key} - response = self.s3_client.list_objects_v2(**parameters) - contents = response.get("Contents", None) - for s3_obj in contents: - key: str = s3_obj.get("Key") - size: bytes = s3_obj.get("Size", None) - last_modified: datetime = s3_obj.get("LastModified", None) - dependency_type: HubContentDependencyType = dependency - files.append( - FileInfo( - location.bucket, - key, - size, - last_modified, - dependency_type, - ) - ) - else: - parameters = {"Bucket": location.bucket, "Key": location.key} - response = self.s3_client.head_object(**parameters) - size: bytes = response.get("ContentLength", None) - last_updated: datetime = response.get("LastModified", None) +def generate_file_infos_from_s3_location( + location: S3ObjectLocation, s3_client: BaseClient +) -> List[FileInfo]: + """Lists objects from an S3 bucket and formats into FileInfo. + + Returns a list of ``FileInfo`` objects from the specified bucket location. + """ + parameters = {"Bucket": location.bucket, "Prefix": location.key} + response = s3_client.list_objects_v2(**parameters) + contents = response.get("Contents", None) + + if not contents: + return [] + + files = [] + for s3_obj in contents: + key: str = s3_obj.get("Key") + size: bytes = s3_obj.get("Size", None) + last_modified: str = s3_obj.get("LastModified", None) + files.append(FileInfo(location.bucket, key, size, last_modified)) + return files + + +def generate_file_infos_from_model_specs( + model_specs: JumpStartModelSpecs, + studio_specs: Dict[str, Any], + region: str, + s3_client: BaseClient, +) -> List[FileInfo]: + """Collects data locations from JumpStart public model specs and converts into `FileInfo`. + + Returns a list of `FileInfo` objects from dependencies found in the public + model specs. + """ + public_model_data_accessor = PublicModelDataAccessor( + region=region, model_specs=model_specs, studio_specs=studio_specs + ) + files = [] + for dependency in HubContentDependencyType: + location: S3ObjectLocation = public_model_data_accessor.get_s3_reference(dependency) + + # Prefix + if location.key.endswith("/"): + parameters = {"Bucket": location.bucket, "Prefix": location.key} + response = s3_client.list_objects_v2(**parameters) + contents = response.get("Contents", None) + for s3_obj in contents: + key: str = s3_obj.get("Key") + size: bytes = s3_obj.get("Size", None) + last_modified: datetime = s3_obj.get("LastModified", None) dependency_type: HubContentDependencyType = dependency files.append( - FileInfo(location.bucket, location.key, size, last_updated, dependency_type) + FileInfo( + location.bucket, + key, + size, + last_modified, + dependency_type, + ) ) - return files + else: + parameters = {"Bucket": location.bucket, "Key": location.key} + response = s3_client.head_object(**parameters) + size: bytes = response.get("ContentLength", None) + last_updated: datetime = response.get("LastModified", None) + dependency_type: HubContentDependencyType = dependency + files.append( + FileInfo(location.bucket, location.key, size, last_updated, dependency_type) + ) + return files diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py b/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py index 9ea7971f41..029b6f1f9e 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py @@ -26,7 +26,7 @@ # pylint: disable=R1705,R1710 -def human_readable_size(value): +def human_readable_size(value: int) -> str: """Convert a size in bytes into a human readable format. For example:: @@ -101,7 +101,7 @@ def _copy_file(self, file: FileInfo, update_fn): ) result.result() - def call(self): + def execute(self): """Something.""" total_size = sum([file.size for file in self.files]) JUMPSTART_LOGGER.warning( diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py index c6f0c61262..6bd4d856a3 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py @@ -15,9 +15,6 @@ from typing import Dict, Any from sagemaker import model_uris, script_uris from sagemaker.jumpstart.curated_hub.accessors.fileinfo import HubContentDependencyType -from sagemaker.jumpstart.curated_hub.utils import ( - get_model_framework, -) from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.types import JumpStartModelSpecs from sagemaker.jumpstart.utils import get_jumpstart_content_bucket @@ -81,14 +78,14 @@ def default_training_dataset_s3_reference(self): @property def demo_notebook_s3_reference(self): """Retrieves s3 reference for model demo jupyter notebook""" - framework = get_model_framework(self.model_specs) + framework = self.model_specs.get_framework() key = f"{framework}-notebooks/{self.model_specs.model_id}-inference.ipynb" return S3ObjectLocation(self._get_bucket_name(), key) @property def markdown_s3_reference(self): """Retrieves s3 reference for model markdown""" - framework = get_model_framework(self.model_specs) + framework = self.model_specs.get_framework() key = f"{framework}-metadata/{self.model_specs.model_id}.md" return S3ObjectLocation(self._get_bucket_name(), key) diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/synccomparator.py b/src/sagemaker/jumpstart/curated_hub/accessors/synccomparator.py index 6103d56d72..b5dfad85aa 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/synccomparator.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/synccomparator.py @@ -18,14 +18,25 @@ from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo -class SizeAndLastUpdatedComparator: - """Something.""" +class BaseComparator: + """BaseComparator object to be extended.""" + + def determine_should_sync(self, src_file: FileInfo, dest_file: FileInfo) -> bool: + """Custom comparator to determine if src file and dest file are in sync.""" + raise NotImplementedError + + +class SizeAndLastUpdatedComparator(BaseComparator): + """Comparator that uses file size and last modified time. + + Uses file size (bytes) and last_modified_time (timestamp) to determine sync. + """ def determine_should_sync(self, src_file: FileInfo, dest_file: FileInfo) -> bool: """Determines if src file should be moved to dest folder.""" same_size = self.compare_size(src_file, dest_file) - same_last_modified_time = self.compare_time(src_file, dest_file) - should_sync = (not same_size) or (not same_last_modified_time) + is_newer_dest_file = self.compare_file_updates(src_file, dest_file) + should_sync = (not same_size) or (not is_newer_dest_file) if should_sync: JUMPSTART_LOGGER.warning( "syncing: %s -> %s, size: %s -> %s, modified time: %s -> %s", @@ -46,7 +57,7 @@ def compare_size(self, src_file: FileInfo, dest_file: FileInfo): """ return src_file.size == dest_file.size - def compare_time(self, src_file: FileInfo, dest_file: FileInfo): + def compare_file_updates(self, src_file: FileInfo, dest_file: FileInfo): """Compares time delta between src and dest files. :returns: True if the file does not need updating based on time of @@ -59,7 +70,6 @@ def compare_time(self, src_file: FileInfo, dest_file: FileInfo): delta = dest_time - src_time # pylint: disable=R1703,R1705 if timedelta.total_seconds(delta) >= 0: - # Destination is newer than source. return True else: # Destination is older than source, so diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/sync.py b/src/sagemaker/jumpstart/curated_hub/accessors/synctask.py similarity index 76% rename from src/sagemaker/jumpstart/curated_hub/accessors/sync.py rename to src/sagemaker/jumpstart/curated_hub/accessors/synctask.py index b724f5daf0..aa347a905a 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/sync.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/synctask.py @@ -17,7 +17,7 @@ from botocore.compat import six from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation -from sagemaker.jumpstart.curated_hub.accessors.synccomparator import SizeAndLastUpdatedComparator +from sagemaker.jumpstart.curated_hub.accessors.synccomparator import BaseComparator from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo @@ -25,8 +25,8 @@ @dataclass -class FileSyncResult: - """File Sync Result class""" +class SyncTaskInfo: + """Sync Task Info class""" files: List[FileInfo] destination: S3ObjectLocation @@ -34,35 +34,50 @@ class FileSyncResult: def __init__( self, files_to_copy: Generator[FileInfo, FileInfo, FileInfo], destination: S3ObjectLocation ): + """Contains information required to sync data. + + Returns: + :var: files (List[FileInfo]): Files that shoudl be synced. + :var: destination (S3ObjectLocation): Location to which to sync the files. + """ self.files = list(files_to_copy) self.destination = destination -class FileSync: - """FileSync class.""" +class SyncTaskHandler: + """Generates a ``SyncTaskInfo`` which contains information required to sync data.""" def __init__( - self, src_files: List[FileInfo], dest_files: List[FileInfo], destination: S3ObjectLocation + self, + src_files: List[FileInfo], + dest_files: List[FileInfo], + destination: S3ObjectLocation, + comparator: BaseComparator, ): - """Instantiates a ``FileSync`` class. Sorts src and dest files by name for comparisons. + """Instantiates a ``SyncTaskGenerator`` class. Args: - src_files (List[FileInfo]): List of files to sync with destination + src_files (List[FileInfo]): List of files to sync to destination bucket dest_files (List[FileInfo]): List of files already in destination bucket - dest_bucket (str): Destination bucket name for copied data + destination (S3ObjectLocation): S3 destination for copied data + + Returns: + ``SyncTaskInfo`` class containing: + :var: files (List[FileInfo]): Files that shoudl be synced. + :var: destination (S3ObjectLocation): Location to which to sync the files. """ - self.comparator = SizeAndLastUpdatedComparator() + self.comparator = comparator self.src_files: List[FileInfo] = sorted(src_files, key=lambda x: x.location.key) self.dest_files: List[FileInfo] = sorted(dest_files, key=lambda x: x.location.key) self.destination = destination - def call(self) -> FileSyncResult: - """Determines which files to copy based on the comparator. + def create(self) -> SyncTaskInfo: + """Creates a ``SyncTaskInfo`` object, which contains `files` to copy and the `destination` - Returns a ``FileSyncResult`` object. + Based on the `s3:sync` algorithm. """ files_to_copy = self._determine_files_to_copy() - return FileSyncResult(files_to_copy, self.destination) + return SyncTaskInfo(files_to_copy, self.destination) def _determine_files_to_copy(self) -> Generator[FileInfo, FileInfo, FileInfo]: """This function performs the actual comparisons. Returns a list of FileInfo to copy. diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 29c3f05289..96b2bf827c 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -23,13 +23,11 @@ from packaging.version import Version from sagemaker.jumpstart import utils -from sagemaker.jumpstart.curated_hub.accessors.filegenerator import ( - ModelSpecsFileGenerator, - S3PathFileGenerator, -) +from sagemaker.jumpstart.curated_hub.accessors import filegenerator from sagemaker.jumpstart.curated_hub.accessors.multipartcopy import MultiPartCopyHandler from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation -from sagemaker.jumpstart.curated_hub.accessors.sync import FileSync +from sagemaker.jumpstart.curated_hub.accessors.synccomparator import SizeAndLastUpdatedComparator +from sagemaker.jumpstart.curated_hub.accessors.synctask import SyncTaskHandler from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.session import Session from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER @@ -157,7 +155,10 @@ def delete(self) -> None: return self._sagemaker_session.delete_hub(self.hub_name) def _is_invalid_model_list_input(self, model_list: List[Dict[str, str]]) -> bool: - """Determines if input args to ``sync`` is correct.""" + """Determines if input args to ``sync`` is correct. + + `model_list` objects must have `model_id` (str) and optional `version` (str). + """ for obj in model_list: if not isinstance(obj.get("model_id"), str): return True @@ -165,35 +166,21 @@ def _is_invalid_model_list_input(self, model_list: List[Dict[str, str]]) -> bool return True return False - def sync(self, model_list: List[Dict[str, str]]): - """Syncs a list of JumpStart model ids and versions with a CuratedHub + def _populate_latest_model_version(self, model: Dict[str, str]) -> Dict[str, str]: + """Retrieves the lastest version of a model that has passed a wildcard ('*'). - Args: - model_list (List[Dict[str, str]]): List of `{ model_id: str, version: Optional[str] }` - objects that should be synced into the Hub. + Returns model ({ model_id: str, version: str }) """ - if self._is_invalid_model_list_input(model_list): - raise ValueError( - "Model list should be a list of objects with values 'model_id',", - "and optional 'version'.", - ) + model_specs = utils.verify_model_region_and_return_specs( + model["model_id"], "*", JumpStartScriptScope.INFERENCE, self.region + ) + model["version"] = model_specs.version + return model - # Fetch required information - # self._get_studio_manifest_map() + def _get_jumpstart_models_in_hub(self) -> List[Dict[str, Any]]: + """Returns list of `HubContent` that have been created from a JumpStart model.""" hub_models = self.list_models() - # Retrieve latest version of unspecified JumpStart model versions - model_version_list = [] - for model in model_list: - version = model.get("version", "*") - if not version or version == "*": - model_specs = utils.verify_model_region_and_return_specs( - model["model_id"], version, JumpStartScriptScope.INFERENCE, self.region - ) - model["version"] = model_specs.version - model_version_list.append(model) - - # Find synced JumpStart model versions in the Hub js_models_in_hub = [] for hub_model in hub_models["HubContentSummaries"]: # TODO: extract both in one pass @@ -217,13 +204,25 @@ def sync(self, model_list: List[Dict[str, str]]): if jumpstart_model_id and jumpstart_model_version: js_models_in_hub.append(hub_model) - # Match inputted list of model versions with synced JumpStart model versions in the Hub + return js_models_in_hub + + def _determine_models_to_sync(self, model_list, models_in_hub) -> List[Dict[str, str]]: + """Determines which models from `sync` params to sync into the CuratedHub. + + Algorithm: + + First, look for a match of model name in Hub. If no model is found, sync that model. + + Next, compare versions of model to sync and what's in the Hub. If version already + in Hub, don't sync. If newer version in Hub, don't sync. If older version in Hub, + sync that model. + """ models_to_sync = [] - for model in model_version_list: + for model in model_list: matched_model = next( ( hub_model - for hub_model in js_models_in_hub + for hub_model in models_in_hub if hub_model and hub_model["name"] == model["model_id"] ), None, @@ -251,9 +250,36 @@ def sync(self, model_list: List[Dict[str, str]]): # Check minSDKVersion against current SDK version, emit log models_to_sync.append(model) + return models_to_sync + + def sync(self, model_list: List[Dict[str, str]]): + """Syncs a list of JumpStart model ids and versions with a CuratedHub + + Args: + model_list (List[Dict[str, str]]): List of `{ model_id: str, version: Optional[str] }` + objects that should be synced into the Hub. + """ + if self._is_invalid_model_list_input(model_list): + raise ValueError( + "Model list should be a list of objects with values 'model_id',", + "and optional 'version'.", + ) + + # Retrieve latest version of unspecified JumpStart model versions + model_version_list = [] + for model in model_list: + version = model.get("version", "*") + if version == "*": + model = self._populate_latest_model_version(model) + model_version_list.append(model) + + js_models_in_hub = self._get_jumpstart_models_in_hub() + + models_to_sync = self._determine_models_to_sync(model_version_list, js_models_in_hub) + # Delete old models? - # Copy content workflow + `SageMaker:ImportHubContent` for each model-to-sync in parallel + # CopyContentWorkflow + `SageMaker:ImportHubContent` for each model-to-sync in parallel tasks: List[futures.Future] = [] with futures.ThreadPoolExecutor( max_workers=self._default_thread_pool_size, @@ -300,26 +326,29 @@ def _sync_public_model_to_hub(self, model: Dict[str, str]): bucket=self.hub_bucket_name, key=f"{model_name}/{model_version}" ) - src_files = ModelSpecsFileGenerator(self.region, self._s3_client, studio_specs).format( - model_specs + src_files = filegenerator.generate_file_infos_from_model_specs( + model_specs, studio_specs, self.region, self._s3_client + ) + dest_files = filegenerator.generate_file_infos_from_s3_location( + dest_location, self._s3_client ) - dest_files = S3PathFileGenerator(self.region, self._s3_client).format(dest_location) - sync_result = FileSync(src_files, dest_files, dest_location).call() + comparator = SizeAndLastUpdatedComparator() + sync_task = SyncTaskHandler(src_files, dest_files, dest_location, comparator).create() - if len(sync_result.files) > 0: + if len(sync_task.files) > 0: MultiPartCopyHandler( - region=self.region, files=sync_result.files, dest_location=sync_result.destination - ).call() + region=self.region, files=sync_task.files, dest_location=sync_task.destination + ).execute() else: - JUMPSTART_LOGGER.warning("[%s/%s] Nothing to copy", model_name, model_version) + JUMPSTART_LOGGER.warning("Nothing to copy for %s v%s", model_name, model_version) - # Tag model if specs say it is deprecated or training/inference vulnerable - # Update tag of HubContent ARN without version. Versioned ARNs are not - # onboarded to Tagris. + # TODO: Tag model if specs say it is deprecated or training/inference + # vulnerable. Update tag of HubContent ARN without version. + # Versioned ARNs are not onboarded to Tagris. tags = [] - hub_content_document = HubContentDocument_v2(spec=model_specs).__str__() + hub_content_document = str(HubContentDocument_v2(spec=model_specs)) self._sagemaker_session.import_hub_content( document_schema_version=HubContentDocument_v2.SCHEMA_VERSION, @@ -340,7 +369,7 @@ def _fetch_studio_specs(self, model_specs: JumpStartModelSpecs) -> Dict[str, Any model_id = model_specs.model_id model_version = model_specs.version - key = f"studio_models/{model_id}/studio_specs_v{model_version}.json" + key = utils.generate_studio_spec_file_prefix(model_id, model_version) response = self._s3_client.get_object( Bucket=utils.get_jumpstart_content_bucket(self.region), Key=key ) diff --git a/src/sagemaker/jumpstart/curated_hub/utils.py b/src/sagemaker/jumpstart/curated_hub/utils.py index c425d05485..ac01da45ca 100644 --- a/src/sagemaker/jumpstart/curated_hub/utils.py +++ b/src/sagemaker/jumpstart/curated_hub/utils.py @@ -19,7 +19,6 @@ from sagemaker.jumpstart.types import ( HubContentType, HubArnExtractedInfo, - JumpStartModelSpecs, ) from sagemaker.jumpstart import constants @@ -153,8 +152,3 @@ def create_hub_bucket_if_it_does_not_exist( ) return bucket_name - - -def get_model_framework(model_specs: JumpStartModelSpecs) -> str: - """Retrieves the model framework from a model spec""" - return model_specs.model_id.split("-")[0] diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 9c457a5626..cd10a7123b 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -964,6 +964,10 @@ def supports_incremental_training(self) -> bool: """Returns True if the model supports incremental training.""" return self.incremental_training_supported + def get_framework(self) -> str: + """Returns the framework for the model.""" + return self.model_id.split("-")[0] + class JumpStartVersionedModelId(JumpStartDataHolderType): """Data class for versioned model IDs.""" diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 1e2bb11d45..f7375b3027 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -829,3 +829,8 @@ def get_jumpstart_model_id_version_from_resource_arn( model_version = model_version_from_tag return model_id, model_version + + +def generate_studio_spec_file_prefix(model_id: str, model_version: str) -> str: + """Returns the Studio Spec file prefix given a model ID and version.""" + return f"studio_models/{model_id}/studio_specs_v{model_version}.json" diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py index 0b54d694e2..4469656385 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py @@ -14,8 +14,8 @@ import pytest from unittest.mock import Mock, patch from sagemaker.jumpstart.curated_hub.accessors.filegenerator import ( - ModelSpecsFileGenerator, - S3PathFileGenerator, + generate_file_infos_from_model_specs, + generate_file_infos_from_s3_location, ) from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo @@ -49,8 +49,7 @@ def test_s3_path_file_generator_happy_path(s3_client): } mock_hub_bucket = S3ObjectLocation(bucket="mock-bucket", key="mock-key") - generator = S3PathFileGenerator("us-west-2", s3_client) - response = generator.format(mock_hub_bucket) + response = generate_file_infos_from_s3_location(mock_hub_bucket, s3_client) s3_client.list_objects_v2.assert_called_once() assert response == [ @@ -65,8 +64,7 @@ def test_model_specs_file_generator_happy_path(patched_get_model_specs, s3_clien specs = JumpStartModelSpecs(BASE_SPEC) studio_specs = {"defaultDataKey": "model_id123"} - generator = ModelSpecsFileGenerator("us-west-2", s3_client, studio_specs) - response = generator.format(specs) + response = generate_file_infos_from_model_specs(specs, studio_specs, "us-west-2", s3_client) s3_client.head_object.assert_called() patched_get_model_specs.assert_called() @@ -116,8 +114,7 @@ def test_s3_path_file_generator_with_no_objects(s3_client): s3_client.list_objects_v2.return_value = {"Contents": []} mock_hub_bucket = S3ObjectLocation(bucket="mock-bucket", key="mock-key") - generator = S3PathFileGenerator("us-west-2", s3_client) - response = generator.format(mock_hub_bucket) + response = generate_file_infos_from_s3_location(mock_hub_bucket, s3_client) s3_client.list_objects_v2.assert_called_once() assert response == [] @@ -127,8 +124,7 @@ def test_s3_path_file_generator_with_no_objects(s3_client): s3_client.list_objects_v2.return_value = {} mock_hub_bucket = S3ObjectLocation(bucket="mock-bucket", key="mock-key") - generator = S3PathFileGenerator("us-west-2", s3_client) - response = generator.format(mock_hub_bucket) + response = generate_file_infos_from_s3_location(mock_hub_bucket, s3_client) s3_client.list_objects_v2.assert_called_once() assert response == [] From 297d1b64e0edc373a984f8389e43a004627838e3 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Thu, 7 Mar 2024 16:02:47 +0000 Subject: [PATCH 08/15] reformatting folders --- .../{filegenerator.py => file_generator.py} | 16 +-- .../curated_hub/accessors/fileinfo.py | 53 ---------- .../curated_hub/accessors/multipartcopy.py | 18 ++-- .../curated_hub/accessors/objectlocation.py | 47 --------- .../accessors/public_model_data.py | 10 +- .../jumpstart/curated_hub/constants.py | 19 ++++ .../jumpstart/curated_hub/curated_hub.py | 97 +++++++++++++------ .../synccomparator.py => sync/comparator.py} | 2 +- .../synctask.py => sync/request.py} | 38 +++++--- src/sagemaker/jumpstart/curated_hub/types.py | 76 ++++++++++++++- .../jumpstart/curated_hub/test_curated_hub.py | 23 ++++- .../curated_hub/test_filegenerator.py | 5 +- .../jumpstart/curated_hub/test_sync.py | 0 .../curated_hub/test_synccomparator.py | 4 +- tests/unit/sagemaker/jumpstart/test_types.py | 5 +- 15 files changed, 230 insertions(+), 183 deletions(-) rename src/sagemaker/jumpstart/curated_hub/accessors/{filegenerator.py => file_generator.py} (89%) delete mode 100644 src/sagemaker/jumpstart/curated_hub/accessors/fileinfo.py delete mode 100644 src/sagemaker/jumpstart/curated_hub/accessors/objectlocation.py create mode 100644 src/sagemaker/jumpstart/curated_hub/constants.py rename src/sagemaker/jumpstart/curated_hub/{accessors/synccomparator.py => sync/comparator.py} (97%) rename src/sagemaker/jumpstart/curated_hub/{accessors/synctask.py => sync/request.py} (78%) delete mode 100644 tests/unit/sagemaker/jumpstart/curated_hub/test_sync.py diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py b/src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py similarity index 89% rename from src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py rename to src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py index b4612f7b08..48918ccf24 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py @@ -17,8 +17,11 @@ from datetime import datetime from botocore.client import BaseClient -from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo, HubContentDependencyType -from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation +from sagemaker.jumpstart.curated_hub.types import ( + FileInfo, + HubContentDependencyType, + S3ObjectLocation, +) from sagemaker.jumpstart.curated_hub.accessors.public_model_data import PublicModelDataAccessor from sagemaker.jumpstart.types import JumpStartModelSpecs @@ -63,10 +66,10 @@ def generate_file_infos_from_model_specs( files = [] for dependency in HubContentDependencyType: location: S3ObjectLocation = public_model_data_accessor.get_s3_reference(dependency) + parameters = {"Bucket": location.bucket, "Prefix": location.key} + location_type = "prefix" if location.key.endswith("/") else "object" - # Prefix - if location.key.endswith("/"): - parameters = {"Bucket": location.bucket, "Prefix": location.key} + if location_type == "prefix": response = s3_client.list_objects_v2(**parameters) contents = response.get("Contents", None) for s3_obj in contents: @@ -83,8 +86,7 @@ def generate_file_infos_from_model_specs( dependency_type, ) ) - else: - parameters = {"Bucket": location.bucket, "Key": location.key} + elif location_type == "object": response = s3_client.head_object(**parameters) size: bytes = response.get("ContentLength", None) last_updated: datetime = response.get("LastModified", None) diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/fileinfo.py b/src/sagemaker/jumpstart/curated_hub/accessors/fileinfo.py deleted file mode 100644 index ea418aebad..0000000000 --- a/src/sagemaker/jumpstart/curated_hub/accessors/fileinfo.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""This module contains important details related to HubContent data files.""" -from __future__ import absolute_import - -from enum import Enum -from dataclasses import dataclass -from typing import Optional -from datetime import datetime - -from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation - - -class HubContentDependencyType(str, Enum): - """Enum class for HubContent dependency names""" - - INFERENCE_ARTIFACT = "inference_artifact_s3_reference" - TRAINING_ARTIFACT = "training_artifact_s3_reference" - INFERENCE_SCRIPT = "inference_script_s3_reference" - TRAINING_SCRIPT = "training_script_s3_reference" - DEFAULT_TRAINING_DATASET = "default_training_dataset_s3_reference" - DEMO_NOTEBOOK = "demo_notebook_s3_reference" - MARKDOWN = "markdown_s3_reference" - - -@dataclass -class FileInfo: - """Data class for additional S3 file info.""" - - location: S3ObjectLocation - - def __init__( - self, - bucket: str, - key: str, - size: Optional[bytes], - last_updated: Optional[datetime], - dependecy_type: Optional[HubContentDependencyType] = None, - ): - self.location = S3ObjectLocation(bucket, key) - self.size = size - self.last_updated = last_updated - self.dependecy_type = dependecy_type diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py b/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py index 029b6f1f9e..fdd7fe8334 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py @@ -12,15 +12,14 @@ # language governing permissions and limitations under the License. """This module provides a class that perfrms functionalities similar to ``S3:Copy``.""" from __future__ import absolute_import -from typing import List import boto3 import botocore import tqdm from sagemaker.jumpstart.constants import JUMPSTART_LOGGER -from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo -from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation +from sagemaker.jumpstart.curated_hub.types import FileInfo +from sagemaker.jumpstart.curated_hub.sync.request import HubSyncRequest s3transfer = boto3.s3.transfer @@ -67,13 +66,12 @@ class MultiPartCopyHandler(object): def __init__( self, region: str, - files: List[FileInfo], - dest_location: S3ObjectLocation, + sync_request: HubSyncRequest, ): """Something.""" self.region = region - self.files = files - self.dest_location = dest_location + self.files = sync_request.files + self.dest_location = sync_request.dest_location config = botocore.config.Config(max_pool_connections=self.WORKERS) self.s3_client = boto3.client("s3", region_name=self.region, config=config) @@ -88,7 +86,7 @@ def __init__( client=self.s3_client, config=transfer_config ) - def _copy_file(self, file: FileInfo, update_fn): + def _copy_file(self, file: FileInfo, progress_cb): """Something.""" copy_source = {"Bucket": file.location.bucket, "Key": file.location.key} result = self.transfer_manager.copy( @@ -96,9 +94,10 @@ def _copy_file(self, file: FileInfo, update_fn): key=f"{self.dest_location.key}/{file.location.key}", copy_source=copy_source, subscribers=[ - s3transfer.ProgressCallbackInvoker(update_fn), + s3transfer.ProgressCallbackInvoker(progress_cb), ], ) + # Attempt to access result to throw error if exists. Silently calls if successful. result.result() def execute(self): @@ -124,5 +123,6 @@ def execute(self): for file in self.files: self._copy_file(file, progress.update) + # Call `shutdown` to wait for copy results self.transfer_manager.shutdown() progress.close() diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/objectlocation.py b/src/sagemaker/jumpstart/curated_hub/accessors/objectlocation.py deleted file mode 100644 index bb72152ee1..0000000000 --- a/src/sagemaker/jumpstart/curated_hub/accessors/objectlocation.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""This module utilites to assist S3 client calls for the Curated Hub.""" -from __future__ import absolute_import -from dataclasses import dataclass -from typing import Dict - -from sagemaker.s3_utils import parse_s3_url - - -@dataclass -class S3ObjectLocation: - """Helper class for S3 object references""" - - bucket: str - key: str - - def format_for_s3_copy(self) -> Dict[str, str]: - """Returns a dict formatted for S3 copy calls""" - return { - "Bucket": self.bucket, - "Key": self.key, - } - - def get_uri(self) -> str: - """Returns the s3 URI""" - return f"s3://{self.bucket}/{self.key}" - - -def create_s3_object_reference_from_uri(s3_uri: str) -> S3ObjectLocation: - """Utiity to help generate an S3 object reference""" - bucket, key = parse_s3_url(s3_uri) - - return S3ObjectLocation( - bucket=bucket, - key=key, - ) diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py index 6bd4d856a3..8712892af9 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py @@ -14,14 +14,14 @@ from __future__ import absolute_import from typing import Dict, Any from sagemaker import model_uris, script_uris -from sagemaker.jumpstart.curated_hub.accessors.fileinfo import HubContentDependencyType -from sagemaker.jumpstart.enums import JumpStartScriptScope -from sagemaker.jumpstart.types import JumpStartModelSpecs -from sagemaker.jumpstart.utils import get_jumpstart_content_bucket -from sagemaker.jumpstart.curated_hub.accessors.objectlocation import ( +from sagemaker.jumpstart.curated_hub.types import ( + HubContentDependencyType, S3ObjectLocation, create_s3_object_reference_from_uri, ) +from sagemaker.jumpstart.enums import JumpStartScriptScope +from sagemaker.jumpstart.types import JumpStartModelSpecs +from sagemaker.jumpstart.utils import get_jumpstart_content_bucket class PublicModelDataAccessor: diff --git a/src/sagemaker/jumpstart/curated_hub/constants.py b/src/sagemaker/jumpstart/curated_hub/constants.py new file mode 100644 index 0000000000..5d35aa80c6 --- /dev/null +++ b/src/sagemaker/jumpstart/curated_hub/constants.py @@ -0,0 +1,19 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module stores constants related to SageMaker JumpStart CuratedHub.""" +from __future__ import absolute_import + +JUMPSTART_HUB_MODEL_ID_TAG_PREFIX = "@jumpstart-model-id" +JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX = "@jumpstart-model-version" +FRAMEWORK_TAG_PREFIX = "@framework" +TASK_TAG_PREFIX = "@mltask" diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 96b2bf827c..0ebe31dd0d 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -23,11 +23,16 @@ from packaging.version import Version from sagemaker.jumpstart import utils -from sagemaker.jumpstart.curated_hub.accessors import filegenerator +from sagemaker.jumpstart.curated_hub.accessors import file_generator from sagemaker.jumpstart.curated_hub.accessors.multipartcopy import MultiPartCopyHandler -from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation -from sagemaker.jumpstart.curated_hub.accessors.synccomparator import SizeAndLastUpdatedComparator -from sagemaker.jumpstart.curated_hub.accessors.synctask import SyncTaskHandler +from sagemaker.jumpstart.curated_hub.constants import ( + JUMPSTART_HUB_MODEL_ID_TAG_PREFIX, + JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX, + TASK_TAG_PREFIX, + FRAMEWORK_TAG_PREFIX, +) +from sagemaker.jumpstart.curated_hub.sync.comparator import SizeAndLastUpdatedComparator +from sagemaker.jumpstart.curated_hub.sync.request import HubSyncRequestFactory from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.session import Session from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER @@ -41,7 +46,11 @@ create_hub_bucket_if_it_does_not_exist, generate_default_hub_bucket_name, ) -from sagemaker.jumpstart.curated_hub.types import HubContentDocument_v2 +from sagemaker.jumpstart.curated_hub.types import ( + HubContentDocument_v2, + JumpStartModelInfo, + S3ObjectLocation, +) class CuratedHub: @@ -78,9 +87,20 @@ def _fetch_hub_bucket_name(self) -> str: hub_bucket_prefix = hub_response["S3StorageConfig"].get("S3OutputPath", None) if hub_bucket_prefix: return hub_bucket_prefix.replace("s3://", "") - return generate_default_hub_bucket_name(self._sagemaker_session) + default_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session) + JUMPSTART_LOGGER.warning( + "There is not a Hub bucket associated with %s. Using %s", + self.hub_name, + default_bucket_name, + ) + return default_bucket_name except exceptions.ClientError: hub_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session) + JUMPSTART_LOGGER.warning( + "There is not a Hub bucket associated with %s. Using %s", + self.hub_name, + default_bucket_name, + ) return hub_bucket_name def create( @@ -188,7 +208,7 @@ def _get_jumpstart_models_in_hub(self) -> List[Dict[str, Any]]: ( tag for tag in hub_model["search_keywords"] - if tag.startswith("@jumpstart-model-id") + if tag.startswith(JUMPSTART_HUB_MODEL_ID_TAG_PREFIX) ), None, ) @@ -196,7 +216,7 @@ def _get_jumpstart_models_in_hub(self) -> List[Dict[str, Any]]: ( tag for tag in hub_model["search_keywords"] - if tag.startswith("@jumpstart-model-version") + if tag.startswith(JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX) ), None, ) @@ -206,7 +226,9 @@ def _get_jumpstart_models_in_hub(self) -> List[Dict[str, Any]]: return js_models_in_hub - def _determine_models_to_sync(self, model_list, models_in_hub) -> List[Dict[str, str]]: + def _determine_models_to_sync( + self, model_list: List[JumpStartModelInfo], models_in_hub + ) -> List[JumpStartModelInfo]: """Determines which models from `sync` params to sync into the CuratedHub. Algorithm: @@ -223,7 +245,7 @@ def _determine_models_to_sync(self, model_list, models_in_hub) -> List[Dict[str, ( hub_model for hub_model in models_in_hub - if hub_model and hub_model["name"] == model["model_id"] + if hub_model and hub_model["name"] == model.model_id ), None, ) @@ -233,7 +255,7 @@ def _determine_models_to_sync(self, model_list, models_in_hub) -> List[Dict[str, models_to_sync.append(model) if matched_model: - model_version = Version(model["version"]) + model_version = Version(model.version) hub_model_version = Version(matched_model["version"]) # 1. Model version exists in Hub, pass @@ -271,11 +293,19 @@ def sync(self, model_list: List[Dict[str, str]]): version = model.get("version", "*") if version == "*": model = self._populate_latest_model_version(model) - model_version_list.append(model) + JUMPSTART_LOGGER.warning( + "No version specified for model %s. Using version %s", + model["model_id"], + model["version"], + ) + model_version_list.append(JumpStartModelInfo(model["model_id"], model["version"])) js_models_in_hub = self._get_jumpstart_models_in_hub() models_to_sync = self._determine_models_to_sync(model_version_list, js_models_in_hub) + JUMPSTART_LOGGER.warning( + "Syncing the following models into Hub %s: %s", self.hub_name, models_to_sync + ) # Delete old models? @@ -308,14 +338,11 @@ def sync(self, model_list: List[Dict[str, str]]): f"Failures when importing models to curated hub in parallel: {failed_imports}" ) - def _sync_public_model_to_hub(self, model: Dict[str, str]): + def _sync_public_model_to_hub(self, model: JumpStartModelInfo): """Syncs a public JumpStart model version to the Hub. Runs in parallel.""" - model_name = model["model_id"] - model_version = model["version"] - model_specs = utils.verify_model_region_and_return_specs( - model_id=model_name, - version=model_version, + model_id=model.model_id, + version=model.version, region=self.region, scope=JumpStartScriptScope.INFERENCE, sagemaker_session=self._sagemaker_session, @@ -323,49 +350,55 @@ def _sync_public_model_to_hub(self, model: Dict[str, str]): studio_specs = self._fetch_studio_specs(model_specs=model_specs) dest_location = S3ObjectLocation( - bucket=self.hub_bucket_name, key=f"{model_name}/{model_version}" + bucket=self.hub_bucket_name, key=f"{model.model_id}/{model.version}" ) - - src_files = filegenerator.generate_file_infos_from_model_specs( + src_files = file_generator.generate_file_infos_from_model_specs( model_specs, studio_specs, self.region, self._s3_client ) - dest_files = filegenerator.generate_file_infos_from_s3_location( + dest_files = file_generator.generate_file_infos_from_s3_location( dest_location, self._s3_client ) comparator = SizeAndLastUpdatedComparator() - sync_task = SyncTaskHandler(src_files, dest_files, dest_location, comparator).create() + sync_request = HubSyncRequestFactory( + src_files, dest_files, dest_location, comparator + ).create() - if len(sync_task.files) > 0: - MultiPartCopyHandler( - region=self.region, files=sync_task.files, dest_location=sync_task.destination - ).execute() + if len(sync_request.files) > 0: + MultiPartCopyHandler(region=self.region, sync_request=sync_request).execute() else: - JUMPSTART_LOGGER.warning("Nothing to copy for %s v%s", model_name, model_version) + JUMPSTART_LOGGER.warning("Nothing to copy for %s v%s", model.model_id, model.version) # TODO: Tag model if specs say it is deprecated or training/inference # vulnerable. Update tag of HubContent ARN without version. # Versioned ARNs are not onboarded to Tagris. tags = [] + search_keywords = [ + f"{JUMPSTART_HUB_MODEL_ID_TAG_PREFIX}:{model.model_id}", + f"{JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX}:{model.version}", + f"{FRAMEWORK_TAG_PREFIX}:{model_specs.get_framework()}", + f"{TASK_TAG_PREFIX}:TODO: pull from specs", + ] + hub_content_document = str(HubContentDocument_v2(spec=model_specs)) self._sagemaker_session.import_hub_content( document_schema_version=HubContentDocument_v2.SCHEMA_VERSION, - hub_content_name=model_name, - hub_content_version=model_version, + hub_content_name=model.model_id, + hub_content_version=model.version, hub_name=self.hub_name, hub_content_document=hub_content_document, hub_content_type=HubContentType.MODEL, hub_content_display_name="", hub_content_description="", hub_content_markdown="", - hub_content_search_keywords=[], + hub_content_search_keywords=search_keywords, tags=tags, ) def _fetch_studio_specs(self, model_specs: JumpStartModelSpecs) -> Dict[str, Any]: - """Fetches StudioSpecs given a models' SDK Specs.""" + """Fetches StudioSpecs given a model's SDK Specs.""" model_id = model_specs.model_id model_version = model_specs.version diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/synccomparator.py b/src/sagemaker/jumpstart/curated_hub/sync/comparator.py similarity index 97% rename from src/sagemaker/jumpstart/curated_hub/accessors/synccomparator.py rename to src/sagemaker/jumpstart/curated_hub/sync/comparator.py index b5dfad85aa..ed7c1b9269 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/synccomparator.py +++ b/src/sagemaker/jumpstart/curated_hub/sync/comparator.py @@ -15,7 +15,7 @@ from datetime import timedelta from sagemaker.jumpstart.constants import JUMPSTART_LOGGER -from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo +from sagemaker.jumpstart.curated_hub.types import FileInfo class BaseComparator: diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/synctask.py b/src/sagemaker/jumpstart/curated_hub/sync/request.py similarity index 78% rename from src/sagemaker/jumpstart/curated_hub/accessors/synctask.py rename to src/sagemaker/jumpstart/curated_hub/sync/request.py index aa347a905a..25f9fb72f4 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/synctask.py +++ b/src/sagemaker/jumpstart/curated_hub/sync/request.py @@ -16,17 +16,16 @@ from typing import Generator, List from botocore.compat import six -from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation -from sagemaker.jumpstart.curated_hub.accessors.synccomparator import BaseComparator -from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo +from sagemaker.jumpstart.curated_hub.sync.comparator import BaseComparator +from sagemaker.jumpstart.curated_hub.types import FileInfo, S3ObjectLocation advance_iterator = six.advance_iterator @dataclass -class SyncTaskInfo: - """Sync Task Info class""" +class HubSyncRequest: + """HubSyncRequest class""" files: List[FileInfo] destination: S3ObjectLocation @@ -34,7 +33,7 @@ class SyncTaskInfo: def __init__( self, files_to_copy: Generator[FileInfo, FileInfo, FileInfo], destination: S3ObjectLocation ): - """Contains information required to sync data. + """Contains information required to sync data into a Hub. Returns: :var: files (List[FileInfo]): Files that shoudl be synced. @@ -44,8 +43,8 @@ def __init__( self.destination = destination -class SyncTaskHandler: - """Generates a ``SyncTaskInfo`` which contains information required to sync data.""" +class HubSyncRequestFactory: + """Generates a ``HubSyncRequest`` which is required to sync data into a Hub.""" def __init__( self, @@ -54,7 +53,7 @@ def __init__( destination: S3ObjectLocation, comparator: BaseComparator, ): - """Instantiates a ``SyncTaskGenerator`` class. + """Instantiates a ``HubSyncRequestFactory`` class. Args: src_files (List[FileInfo]): List of files to sync to destination bucket @@ -62,22 +61,23 @@ def __init__( destination (S3ObjectLocation): S3 destination for copied data Returns: - ``SyncTaskInfo`` class containing: + ``HubSyncRequest`` class containing: :var: files (List[FileInfo]): Files that shoudl be synced. :var: destination (S3ObjectLocation): Location to which to sync the files. """ self.comparator = comparator + # Need the file lists to be sorted for comparisons below self.src_files: List[FileInfo] = sorted(src_files, key=lambda x: x.location.key) self.dest_files: List[FileInfo] = sorted(dest_files, key=lambda x: x.location.key) self.destination = destination - def create(self) -> SyncTaskInfo: - """Creates a ``SyncTaskInfo`` object, which contains `files` to copy and the `destination` + def create(self) -> HubSyncRequest: + """Creates a ``HubSyncRequest`` object, which contains `files` to copy and the `destination` Based on the `s3:sync` algorithm. """ files_to_copy = self._determine_files_to_copy() - return SyncTaskInfo(files_to_copy, self.destination) + return HubSyncRequest(files_to_copy, self.destination) def _determine_files_to_copy(self) -> Generator[FileInfo, FileInfo, FileInfo]: """This function performs the actual comparisons. Returns a list of FileInfo to copy. @@ -89,6 +89,10 @@ def _determine_files_to_copy(self) -> Generator[FileInfo, FileInfo, FileInfo]: If there is a dest file, compare file names. If the file names are equivalent, use the comparator to see if the dest file should be updated. If the file names are not equivalent, increment the dest pointer. + + Notes: + Increment src_files = continue, Increment dest_files = advance_iterator(iterator), + Take src_file = yield """ # :var dest_done: True if there are no files from the dest left. dest_done = False @@ -119,7 +123,9 @@ def _determine_files_to_copy(self) -> Generator[FileInfo, FileInfo, FileInfo]: dest_done = True continue - # Past the src file alphabetically in dest file list. Take the src file and continue + # Past the src file alphabetically in dest file list. Take the src file and increment src_files. + # If there is an alpha-smaller file name in dest as compared to src, it means there is an + # unexpected file in dest. Do nothing and continue to the next src_file if self._is_alphabetically_larger_file_name( src_file.location.key, dest_file.location.key ): @@ -127,9 +133,9 @@ def _determine_files_to_copy(self) -> Generator[FileInfo, FileInfo, FileInfo]: continue def _is_same_file_name(self, src_filename: str, dest_filename: str) -> bool: - """Compares two file names and determiens if they are the same. + """Determines if two files have the same base path and file name. - Destination files might add a prefix. + Destination files might have a prefix, so account for that. """ return dest_filename.endswith(src_filename) diff --git a/src/sagemaker/jumpstart/curated_hub/types.py b/src/sagemaker/jumpstart/curated_hub/types.py index b788bc1d7c..35cbbc4614 100644 --- a/src/sagemaker/jumpstart/curated_hub/types.py +++ b/src/sagemaker/jumpstart/curated_hub/types.py @@ -10,11 +10,83 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""This module stores types related to SageMaker JumpStart.""" +"""This module stores types related to SageMaker JumpStart CuratedHub.""" from __future__ import absolute_import -from typing import Dict, Any +from typing import Dict, Any, Optional +from enum import Enum +from dataclasses import dataclass +from datetime import datetime from sagemaker.jumpstart.types import JumpStartDataHolderType, JumpStartModelSpecs +from sagemaker.s3_utils import parse_s3_url + + +@dataclass +class S3ObjectLocation: + """Helper class for S3 object references""" + + bucket: str + key: str + + def format_for_s3_copy(self) -> Dict[str, str]: + """Returns a dict formatted for S3 copy calls""" + return { + "Bucket": self.bucket, + "Key": self.key, + } + + def get_uri(self) -> str: + """Returns the s3 URI""" + return f"s3://{self.bucket}/{self.key}" + + +def create_s3_object_reference_from_uri(s3_uri: str) -> S3ObjectLocation: + """Utiity to help generate an S3 object reference""" + bucket, key = parse_s3_url(s3_uri) + + return S3ObjectLocation( + bucket=bucket, + key=key, + ) + + +@dataclass +class JumpStartModelInfo: + """Helper class for storing JumpStart model info.""" + + model_id: str + version: Optional[str] = None + + +class HubContentDependencyType(str, Enum): + """Enum class for HubContent dependency names""" + + INFERENCE_ARTIFACT = "inference_artifact_s3_reference" + TRAINING_ARTIFACT = "training_artifact_s3_reference" + INFERENCE_SCRIPT = "inference_script_s3_reference" + TRAINING_SCRIPT = "training_script_s3_reference" + DEFAULT_TRAINING_DATASET = "default_training_dataset_s3_reference" + DEMO_NOTEBOOK = "demo_notebook_s3_reference" + MARKDOWN = "markdown_s3_reference" + + +class FileInfo(JumpStartDataHolderType): + """Data class for additional S3 file info.""" + + location: S3ObjectLocation + + def __init__( + self, + bucket: str, + key: str, + size: Optional[bytes], + last_updated: Optional[datetime], + dependecy_type: Optional[HubContentDependencyType] = None, + ): + self.location = S3ObjectLocation(bucket, key) + self.size = size + self.last_updated = last_updated + self.dependecy_type = dependecy_type class HubContentDocument_v2(JumpStartDataHolderType): diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py index d9c4a37e94..52c284fe9e 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py @@ -16,6 +16,7 @@ import pytest from mock import Mock from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub +from sagemaker.jumpstart.curated_hub.types import JumpStartModelInfo from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec REGION = "us-east-1" @@ -157,7 +158,12 @@ def test_sync_kicks_off_parallel_syncs( hub.sync([model_one, model_two]) - mock_sync_public_models.assert_has_calls([mock.call(model_one), mock.call(model_two)]) + mock_sync_public_models.assert_has_calls( + [ + mock.call(JumpStartModelInfo("mock-model-one-huggingface", "*")), + mock.call(JumpStartModelInfo("mock-model-two-pytorch", "1.0.2")), + ] + ) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -196,7 +202,9 @@ def test_sync_filters_models_that_exist_in_hub( hub.sync([model_one, model_two]) - mock_sync_public_models.assert_called_once_with(model_one) + mock_sync_public_models.assert_called_once_with( + JumpStartModelInfo("mock-model-one-huggingface", "*") + ) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -239,7 +247,12 @@ def test_sync_updates_old_models_in_hub( hub.sync([model_one, model_two]) - mock_sync_public_models.assert_has_calls([mock.call(model_one), mock.call(model_two)]) + mock_sync_public_models.assert_has_calls( + [ + mock.call(JumpStartModelInfo("mock-model-one-huggingface", "*")), + mock.call(JumpStartModelInfo("mock-model-two-pytorch", "1.0.2")), + ] + ) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -282,4 +295,6 @@ def test_sync_passes_newer_hub_models( hub.sync([model_one, model_two]) - mock_sync_public_models.assert_called_once_with(model_one) + mock_sync_public_models.assert_called_once_with( + JumpStartModelInfo("mock-model-one-huggingface", "*") + ) diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py index 4469656385..accd2a5c8d 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py @@ -13,13 +13,12 @@ from __future__ import absolute_import import pytest from unittest.mock import Mock, patch -from sagemaker.jumpstart.curated_hub.accessors.filegenerator import ( +from sagemaker.jumpstart.curated_hub.accessors.file_generator import ( generate_file_infos_from_model_specs, generate_file_infos_from_s3_location, ) -from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo +from sagemaker.jumpstart.curated_hub.types import FileInfo, S3ObjectLocation -from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation from sagemaker.jumpstart.types import JumpStartModelSpecs from tests.unit.sagemaker.jumpstart.constants import BASE_SPEC from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_sync.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_sync.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_synccomparator.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_synccomparator.py index b5e1654f05..4753820320 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_synccomparator.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_synccomparator.py @@ -13,9 +13,9 @@ from __future__ import absolute_import import unittest from datetime import datetime -from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo +from sagemaker.jumpstart.curated_hub.types import FileInfo -from sagemaker.jumpstart.curated_hub.accessors.synccomparator import SizeAndLastUpdatedComparator +from sagemaker.jumpstart.curated_hub.sync.comparator import SizeAndLastUpdatedComparator class SizeAndLastUpdateComparatorTest(unittest.TestCase): diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 82e69e1d89..7f842a053c 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -407,8 +407,9 @@ def test_jumpstart_model_specs(): assert specs1.to_json() == BASE_SPEC - BASE_SPEC["model_id"] = "diff model ID" - specs2 = JumpStartModelSpecs(BASE_SPEC) + diff_specs = copy.deepcopy(BASE_SPEC) + diff_specs["model_id"] = "diff model ID" + specs2 = JumpStartModelSpecs(diff_specs) assert specs1 != specs2 specs3 = copy.deepcopy(specs1) From 18a5728afd925c24f077908d43765e20a0d97e24 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Thu, 7 Mar 2024 18:09:12 +0000 Subject: [PATCH 09/15] testing for sync --- .../curated_hub/accessors/file_generator.py | 3 +- .../curated_hub/accessors/multipartcopy.py | 17 +- .../jumpstart/curated_hub/curated_hub.py | 5 +- .../jumpstart/curated_hub/sync/request.py | 29 ++- src/sagemaker/jumpstart/curated_hub/types.py | 2 +- .../jumpstart/curated_hub/test_curated_hub.py | 188 ++++++++++++++++++ .../jumpstart/curated_hub/test_syncrequest.py | 138 +++++++++++++ 7 files changed, 362 insertions(+), 20 deletions(-) create mode 100644 tests/unit/sagemaker/jumpstart/curated_hub/test_syncrequest.py diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py b/src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py index 48918ccf24..3e1ac849cb 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py @@ -66,10 +66,10 @@ def generate_file_infos_from_model_specs( files = [] for dependency in HubContentDependencyType: location: S3ObjectLocation = public_model_data_accessor.get_s3_reference(dependency) - parameters = {"Bucket": location.bucket, "Prefix": location.key} location_type = "prefix" if location.key.endswith("/") else "object" if location_type == "prefix": + parameters = {"Bucket": location.bucket, "Prefix": location.key} response = s3_client.list_objects_v2(**parameters) contents = response.get("Contents", None) for s3_obj in contents: @@ -87,6 +87,7 @@ def generate_file_infos_from_model_specs( ) ) elif location_type == "object": + parameters = {"Bucket": location.bucket, "Key": location.key} response = s3_client.head_object(**parameters) size: bytes = response.get("ContentLength", None) last_updated: datetime = response.get("LastModified", None) diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py b/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py index fdd7fe8334..96f7b77c61 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py @@ -68,10 +68,16 @@ def __init__( region: str, sync_request: HubSyncRequest, ): - """Something.""" + """Multi-part S3:Copy Handler initializer. + + Args: + region (str): Region for the S3 Client + sync_request (HubSyncRequest): sync request object containing + information required to perform the copy + """ self.region = region self.files = sync_request.files - self.dest_location = sync_request.dest_location + self.dest_location = sync_request.destination config = botocore.config.Config(max_pool_connections=self.WORKERS) self.s3_client = boto3.client("s3", region_name=self.region, config=config) @@ -87,7 +93,7 @@ def __init__( ) def _copy_file(self, file: FileInfo, progress_cb): - """Something.""" + """Performs the actual MultiPart S3:Copy of the object.""" copy_source = {"Bucket": file.location.bucket, "Key": file.location.key} result = self.transfer_manager.copy( bucket=self.dest_location.bucket, @@ -101,7 +107,10 @@ def _copy_file(self, file: FileInfo, progress_cb): result.result() def execute(self): - """Something.""" + """Executes the MultiPart S3:Copy on the class. + + Sets up progress bar and kicks off each copy request. + """ total_size = sum([file.size for file in self.files]) JUMPSTART_LOGGER.warning( "Copying %s files (%s) into %s/%s", diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 0ebe31dd0d..eccb02b4fd 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -187,15 +187,14 @@ def _is_invalid_model_list_input(self, model_list: List[Dict[str, str]]) -> bool return False def _populate_latest_model_version(self, model: Dict[str, str]) -> Dict[str, str]: - """Retrieves the lastest version of a model that has passed a wildcard ('*'). + """Populates the lastest version of a model from specs no matter what is passed. Returns model ({ model_id: str, version: str }) """ model_specs = utils.verify_model_region_and_return_specs( model["model_id"], "*", JumpStartScriptScope.INFERENCE, self.region ) - model["version"] = model_specs.version - return model + return {"model_id": model["model_id"], "version": model_specs.version} def _get_jumpstart_models_in_hub(self) -> List[Dict[str, Any]]: """Returns list of `HubContent` that have been created from a JumpStart model.""" diff --git a/src/sagemaker/jumpstart/curated_hub/sync/request.py b/src/sagemaker/jumpstart/curated_hub/sync/request.py index 25f9fb72f4..5d3724da6e 100644 --- a/src/sagemaker/jumpstart/curated_hub/sync/request.py +++ b/src/sagemaker/jumpstart/curated_hub/sync/request.py @@ -66,10 +66,17 @@ def __init__( :var: destination (S3ObjectLocation): Location to which to sync the files. """ self.comparator = comparator + self.destination = destination # Need the file lists to be sorted for comparisons below self.src_files: List[FileInfo] = sorted(src_files, key=lambda x: x.location.key) - self.dest_files: List[FileInfo] = sorted(dest_files, key=lambda x: x.location.key) - self.destination = destination + formatted_dest_files = [self._format_dest_file(file) for file in dest_files] + self.dest_files: List[FileInfo] = sorted(formatted_dest_files, key=lambda x: x.location.key) + + def _format_dest_file(self, file: FileInfo) -> FileInfo: + """Strips HubContent data prefix from dest file name""" + formatted_key = file.location.key.replace(f"{self.destination.key}/", "") + file.location.key = formatted_key + return file def create(self) -> HubSyncRequest: """Creates a ``HubSyncRequest`` object, which contains `files` to copy and the `destination` @@ -106,11 +113,13 @@ def _determine_files_to_copy(self) -> Generator[FileInfo, FileInfo, FileInfo]: for src_file in self.src_files: # End of dest, yield remaining src_files if dest_done: + print("here????", src_file.location.key) yield src_file continue # We've identified two files that have the same name, further compare if self._is_same_file_name(src_file.location.key, dest_file.location.key): + print("wait i'm asdfd", src_file.location.key, dest_file.location.key) should_sync = self.comparator.determine_should_sync(src_file, dest_file) if should_sync: @@ -124,21 +133,19 @@ def _determine_files_to_copy(self) -> Generator[FileInfo, FileInfo, FileInfo]: continue # Past the src file alphabetically in dest file list. Take the src file and increment src_files. - # If there is an alpha-smaller file name in dest as compared to src, it means there is an + # If there is an alpha-larger file name in dest as compared to src, it means there is an # unexpected file in dest. Do nothing and continue to the next src_file - if self._is_alphabetically_larger_file_name( + if self._is_alphabetically_earlier_file_name( src_file.location.key, dest_file.location.key ): + print("wait i'm here", src_file.location.key, dest_file.location.key) yield src_file continue def _is_same_file_name(self, src_filename: str, dest_filename: str) -> bool: - """Determines if two files have the same base path and file name. - - Destination files might have a prefix, so account for that. - """ - return dest_filename.endswith(src_filename) + """Determines if two files have the same file name.""" + return src_filename == dest_filename - def _is_alphabetically_larger_file_name(self, src_filename: str, dest_filename: str) -> bool: + def _is_alphabetically_earlier_file_name(self, src_filename: str, dest_filename: str) -> bool: """Determines if one filename is alphabetically earlier than another.""" - return src_filename > dest_filename + return src_filename < dest_filename diff --git a/src/sagemaker/jumpstart/curated_hub/types.py b/src/sagemaker/jumpstart/curated_hub/types.py index 35cbbc4614..28c5f39300 100644 --- a/src/sagemaker/jumpstart/curated_hub/types.py +++ b/src/sagemaker/jumpstart/curated_hub/types.py @@ -55,7 +55,7 @@ class JumpStartModelInfo: """Helper class for storing JumpStart model info.""" model_id: str - version: Optional[str] = None + version: str class HubContentDependencyType(str, Enum): diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py index 52c284fe9e..5f86c25a65 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py @@ -11,12 +11,15 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import +from copy import deepcopy from unittest import mock from unittest.mock import patch import pytest from mock import Mock from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub from sagemaker.jumpstart.curated_hub.types import JumpStartModelInfo +from sagemaker.jumpstart.types import JumpStartModelSpecs +from tests.unit.sagemaker.jumpstart.constants import BASE_SPEC from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec REGION = "us-east-1" @@ -298,3 +301,188 @@ def test_sync_passes_newer_hub_models( mock_sync_public_models.assert_called_once_with( JumpStartModelInfo("mock-model-one-huggingface", "*") ) + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_populate_latest_model_version(mock_get_model_specs, sagemaker_session): + mock_get_model_specs.return_value = JumpStartModelSpecs(deepcopy(BASE_SPEC)) + + hub_name = "mock_hub_name" + hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session) + + res = hub._populate_latest_model_version({"model_id": "mock-pytorch-model-one", "version": "*"}) + assert res == {"model_id": "mock-pytorch-model-one", "version": "1.0.0"} + + res = hub._populate_latest_model_version({"model_id": "mock-pytorch-model-one"}) + assert res == {"model_id": "mock-pytorch-model-one", "version": "1.0.0"} + + # Should take latest version from specs no matter what. Parent should responsibly call + res = hub._populate_latest_model_version( + {"model_id": "mock-pytorch-model-one", "version": "2.0.0"} + ) + assert res == {"model_id": "mock-pytorch-model-one", "version": "1.0.0"} + + +@patch(f"{MODULE_PATH}.list_models") +def test_get_jumpstart_models_in_hub(mock_list_models, sagemaker_session): + mock_list_models.return_value = { + "HubContentSummaries": [ + { + "name": "mock-model-two-pytorch", + "version": "1.0.3", + "search_keywords": [ + "@jumpstart-model-id:model-two-pytorch", + "@jumpstart-model-version:1.0.3", + ], + }, + { + "name": "mock-model-three-nonsense", + "version": "1.0.2", + "search_keywords": ["tag-one", "tag-two"], + }, + { + "name": "mock-model-four-huggingface", + "version": "2.0.2", + "search_keywords": [ + "@jumpstart-model-id:model-four-huggingface", + "@jumpstart-model-version:2.0.2", + ], + }, + ] + } + + hub_name = "mock_hub_name" + hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session) + + res = hub._get_jumpstart_models_in_hub() + assert res == [ + { + "name": "mock-model-two-pytorch", + "version": "1.0.3", + "search_keywords": [ + "@jumpstart-model-id:model-two-pytorch", + "@jumpstart-model-version:1.0.3", + ], + }, + { + "name": "mock-model-four-huggingface", + "version": "2.0.2", + "search_keywords": [ + "@jumpstart-model-id:model-four-huggingface", + "@jumpstart-model-version:2.0.2", + ], + }, + ] + + mock_list_models.return_value = {"HubContentSummaries": []} + + res = hub._get_jumpstart_models_in_hub() + assert res == [] + + mock_list_models.return_value = { + "HubContentSummaries": [ + { + "name": "mock-model-three-nonsense", + "version": "1.0.2", + "search_keywords": ["tag-one", "tag-two"], + }, + ] + } + + res = hub._get_jumpstart_models_in_hub() + assert res == [] + + +def test_determine_models_to_sync(sagemaker_session): + hub_name = "mock_hub_name" + hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session) + + js_models_in_hub = [ + { + "name": "mock-model-two-pytorch", + "version": "1.0.1", + "search_keywords": [ + "@jumpstart-model-id:model-two-pytorch", + "@jumpstart-model-version:1.0.2", + ], + }, + { + "name": "mock-model-four-huggingface", + "version": "2.0.2", + "search_keywords": [ + "@jumpstart-model-id:model-four-huggingface", + "@jumpstart-model-version:2.0.2", + ], + }, + ] + model_one = JumpStartModelInfo("mock-model-one-huggingface", "1.2.3") + model_two = JumpStartModelInfo("mock-model-two-pytorch", "1.0.2") + # No model_one, older model_two + res = hub._determine_models_to_sync([model_one, model_two], js_models_in_hub) + assert res == [model_one, model_two] + + js_models_in_hub = [ + { + "name": "mock-model-two-pytorch", + "version": "1.0.3", + "search_keywords": [ + "@jumpstart-model-id:model-two-pytorch", + "@jumpstart-model-version:1.0.3", + ], + }, + { + "name": "mock-model-four-huggingface", + "version": "2.0.2", + "search_keywords": [ + "@jumpstart-model-id:model-four-huggingface", + "@jumpstart-model-version:2.0.2", + ], + }, + ] + # No model_one, newer model_two + res = hub._determine_models_to_sync([model_one, model_two], js_models_in_hub) + assert res == [model_one] + + js_models_in_hub = [ + { + "name": "mock-model-one-huggingface", + "version": "1.2.3", + "search_keywords": [ + "@jumpstart-model-id:model-one-huggingface", + "@jumpstart-model-version:1.2.3", + ], + }, + { + "name": "mock-model-two-pytorch", + "version": "1.0.2", + "search_keywords": [ + "@jumpstart-model-id:model-two-pytorch", + "@jumpstart-model-version:1.0.2", + ], + }, + ] + # Same model_one, same model_two + res = hub._determine_models_to_sync([model_one, model_two], js_models_in_hub) + assert res == [] + + js_models_in_hub = [ + { + "name": "mock-model-one-huggingface", + "version": "1.2.1", + "search_keywords": [ + "@jumpstart-model-id:model-one-huggingface", + "@jumpstart-model-version:1.2.1", + ], + }, + { + "name": "mock-model-two-pytorch", + "version": "1.0.2", + "search_keywords": [ + "@jumpstart-model-id:model-two-pytorch", + "@jumpstart-model-version:1.0.2", + ], + }, + ] + # Old model_one, same model_two + res = hub._determine_models_to_sync([model_one, model_two], js_models_in_hub) + assert res == [model_one] diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_syncrequest.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_syncrequest.py new file mode 100644 index 0000000000..497f0f4d60 --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_syncrequest.py @@ -0,0 +1,138 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +from typing import List, Optional +from datetime import datetime + +import pytest +from sagemaker.jumpstart.curated_hub.sync.comparator import SizeAndLastUpdatedComparator + +from sagemaker.jumpstart.curated_hub.sync.request import HubSyncRequestFactory +from sagemaker.jumpstart.curated_hub.types import ( + FileInfo, + HubContentDependencyType, + S3ObjectLocation, +) + +COMPARATOR = SizeAndLastUpdatedComparator() + + +def _helper_generate_fileinfos( + num_infos: int, + bucket: Optional[str] = None, + key_prefix: Optional[str] = None, + size: Optional[int] = None, + last_updated: Optional[datetime] = None, + dependecy_type: Optional[HubContentDependencyType] = None, +) -> List[FileInfo]: + + file_infos = [] + for i in range(num_infos): + bucket = bucket or "default-bucket" + key_prefix = key_prefix or "mock-key" + size = size or 123456 + last_updated = last_updated or datetime.today() + + file_infos.append( + FileInfo( + bucket=bucket, + key=f"{key_prefix}-{i}", + size=size, + last_updated=last_updated, + dependecy_type={dependecy_type}, + ) + ) + return file_infos + + +@pytest.mark.parametrize( + ("src_files,dest_files"), + [ + pytest.param(_helper_generate_fileinfos(8), []), + pytest.param([], _helper_generate_fileinfos(8)), + pytest.param([], []), + ], +) +def test_sync_request_factory_edge_cases(src_files, dest_files): + dest_location = S3ObjectLocation("mock-bucket-123", "mock-prefix") + factory = HubSyncRequestFactory(src_files, dest_files, dest_location, COMPARATOR) + + req = factory.create() + + assert req.files == src_files + assert req.destination == dest_location + + +def test_passes_existing_files_in_dest(): + files = _helper_generate_fileinfos(4, key_prefix="aafile.py") + tarballs = _helper_generate_fileinfos(3, key_prefix="bb.tar.gz") + extra_files = _helper_generate_fileinfos(2, key_prefix="ccextrafiles.py") + + src_files = [*tarballs, *files, *extra_files] + dest_files = [files[1], files[2], tarballs[1]] + + expected_response = [files[0], files[3], tarballs[0], tarballs[2], *extra_files] + + dest_location = S3ObjectLocation("mock-bucket-123", "mock-prefix") + factory = HubSyncRequestFactory(src_files, dest_files, dest_location, COMPARATOR) + + req = factory.create() + + assert req.files == expected_response + + +def test_adds_files_with_same_name_diff_size(): + file_one = _helper_generate_fileinfos(1, key_prefix="file.py", size=101010)[0] + file_two = _helper_generate_fileinfos(1, key_prefix="file.py", size=123456)[0] + + src_files = [file_one] + dest_files = [file_two] + + dest_location = S3ObjectLocation("mock-bucket-123", "mock-prefix") + factory = HubSyncRequestFactory(src_files, dest_files, dest_location, COMPARATOR) + + req = factory.create() + + assert req.files == src_files + + +def test_adds_files_with_same_name_dest_older_time(): + file_dest = _helper_generate_fileinfos(1, key_prefix="file.py", last_updated=datetime.today())[ + 0 + ] + file_src = _helper_generate_fileinfos(1, key_prefix="file.py", size=datetime.today())[0] + + src_files = [file_src] + dest_files = [file_dest] + + dest_location = S3ObjectLocation("mock-bucket-123", "mock-prefix") + factory = HubSyncRequestFactory(src_files, dest_files, dest_location, COMPARATOR) + + req = factory.create() + + assert req.files == src_files + + +def test_does_not_add_files_with_same_name_src_older_time(): + file_src = _helper_generate_fileinfos(1, key_prefix="file.py", last_updated=datetime.today())[0] + file_dest = _helper_generate_fileinfos(1, key_prefix="file.py", size=datetime.today())[0] + + src_files = [file_src] + dest_files = [file_dest] + + dest_location = S3ObjectLocation("mock-bucket-123", "mock-prefix") + factory = HubSyncRequestFactory(src_files, dest_files, dest_location, COMPARATOR) + + req = factory.create() + + assert req.files == src_files From 4554d34b0077d33a2d18e11ca0345978c19d70a7 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Thu, 7 Mar 2024 20:17:36 +0000 Subject: [PATCH 10/15] do not tolerate vulnerable --- .../jumpstart/curated_hub/accessors/public_model_data.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py index 8712892af9..1f9d8683af 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py @@ -104,8 +104,6 @@ def _jumpstart_script_s3_uri(self, model_scope: str) -> str: model_id=self.model_specs.model_id, model_version=self.model_specs.version, script_scope=model_scope, - tolerate_vulnerable_model=True, - tolerate_deprecated_model=True, ) def _jumpstart_artifact_s3_uri(self, model_scope: str) -> str: @@ -115,6 +113,4 @@ def _jumpstart_artifact_s3_uri(self, model_scope: str) -> str: model_id=self.model_specs.model_id, model_version=self.model_specs.version, model_scope=model_scope, - tolerate_vulnerable_model=True, - tolerate_deprecated_model=True, ) From c7f3f961e54f8e9675b07ed2ed1b2d4281350f73 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Thu, 7 Mar 2024 20:28:14 +0000 Subject: [PATCH 11/15] remove prints --- src/sagemaker/jumpstart/curated_hub/curated_hub.py | 2 +- src/sagemaker/jumpstart/curated_hub/sync/request.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index eccb02b4fd..6c745d1649 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -99,7 +99,7 @@ def _fetch_hub_bucket_name(self) -> str: JUMPSTART_LOGGER.warning( "There is not a Hub bucket associated with %s. Using %s", self.hub_name, - default_bucket_name, + hub_bucket_name, ) return hub_bucket_name diff --git a/src/sagemaker/jumpstart/curated_hub/sync/request.py b/src/sagemaker/jumpstart/curated_hub/sync/request.py index 5d3724da6e..11e18eddd2 100644 --- a/src/sagemaker/jumpstart/curated_hub/sync/request.py +++ b/src/sagemaker/jumpstart/curated_hub/sync/request.py @@ -113,13 +113,11 @@ def _determine_files_to_copy(self) -> Generator[FileInfo, FileInfo, FileInfo]: for src_file in self.src_files: # End of dest, yield remaining src_files if dest_done: - print("here????", src_file.location.key) yield src_file continue # We've identified two files that have the same name, further compare if self._is_same_file_name(src_file.location.key, dest_file.location.key): - print("wait i'm asdfd", src_file.location.key, dest_file.location.key) should_sync = self.comparator.determine_should_sync(src_file, dest_file) if should_sync: @@ -138,7 +136,6 @@ def _determine_files_to_copy(self) -> Generator[FileInfo, FileInfo, FileInfo]: if self._is_alphabetically_earlier_file_name( src_file.location.key, dest_file.location.key ): - print("wait i'm here", src_file.location.key, dest_file.location.key) yield src_file continue From 28c9186645cfbb795ef5e7924caac4f6a388c142 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Thu, 7 Mar 2024 21:08:40 +0000 Subject: [PATCH 12/15] handle multithreading progress bar --- .../jumpstart/curated_hub/accessors/multipartcopy.py | 9 +++++++-- src/sagemaker/jumpstart/curated_hub/curated_hub.py | 11 ++++++----- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py b/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py index 96f7b77c61..e5d95e589c 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. """This module provides a class that perfrms functionalities similar to ``S3:Copy``.""" from __future__ import absolute_import +from typing import Optional import boto3 import botocore @@ -67,6 +68,8 @@ def __init__( self, region: str, sync_request: HubSyncRequest, + label: Optional[str] = None, + thread_num: Optional[int] = 0 ): """Multi-part S3:Copy Handler initializer. @@ -75,9 +78,11 @@ def __init__( sync_request (HubSyncRequest): sync request object containing information required to perform the copy """ + self.label = label self.region = region self.files = sync_request.files self.dest_location = sync_request.destination + self.thread_num = thread_num config = botocore.config.Config(max_pool_connections=self.WORKERS) self.s3_client = boto3.client("s3", region_name=self.region, config=config) @@ -121,11 +126,11 @@ def execute(self): ) progress = tqdm.tqdm( - desc="JumpStart Sync", + desc=self.label, total=total_size, unit="B", unit_scale=1, - position=0, + position=self.thread_num, bar_format="{desc:<10}{percentage:3.0f}%|{bar:10}{r_bar}", ) diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 6c745d1649..a58cd34c67 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -21,6 +21,7 @@ from botocore import exceptions from botocore.client import BaseClient from packaging.version import Version +import tqdm from sagemaker.jumpstart import utils from sagemaker.jumpstart.curated_hub.accessors import file_generator @@ -314,9 +315,9 @@ def sync(self, model_list: List[Dict[str, str]]): max_workers=self._default_thread_pool_size, thread_name_prefix="import-models-to-curated-hub", ) as deploy_executor: - for model in models_to_sync: - task = deploy_executor.submit(self._sync_public_model_to_hub, model) - tasks.append(task) + for thread_num, model in enumerate(models_to_sync): + task = deploy_executor.submit(self._sync_public_model_to_hub, model, thread_num) + tasks.append(task) # Handle failed imports results = futures.wait(tasks) @@ -337,7 +338,7 @@ def sync(self, model_list: List[Dict[str, str]]): f"Failures when importing models to curated hub in parallel: {failed_imports}" ) - def _sync_public_model_to_hub(self, model: JumpStartModelInfo): + def _sync_public_model_to_hub(self, model: JumpStartModelInfo, thread_num: int): """Syncs a public JumpStart model version to the Hub. Runs in parallel.""" model_specs = utils.verify_model_region_and_return_specs( model_id=model.model_id, @@ -364,7 +365,7 @@ def _sync_public_model_to_hub(self, model: JumpStartModelInfo): ).create() if len(sync_request.files) > 0: - MultiPartCopyHandler(region=self.region, sync_request=sync_request).execute() + MultiPartCopyHandler(thread_num=thread_num, sync_request=sync_request, region=self.region, label=dest_location.key).execute() else: JUMPSTART_LOGGER.warning("Nothing to copy for %s v%s", model.model_id, model.version) From 97001cc260027afe45d26a270fa0a408fa72aebe Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Thu, 7 Mar 2024 21:25:01 +0000 Subject: [PATCH 13/15] update tests --- .../curated_hub/accessors/multipartcopy.py | 2 +- src/sagemaker/jumpstart/curated_hub/curated_hub.py | 14 +++++++++----- .../jumpstart/curated_hub/test_curated_hub.py | 12 ++++++------ 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py b/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py index e5d95e589c..3cf20ed500 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py @@ -69,7 +69,7 @@ def __init__( region: str, sync_request: HubSyncRequest, label: Optional[str] = None, - thread_num: Optional[int] = 0 + thread_num: Optional[int] = 0, ): """Multi-part S3:Copy Handler initializer. diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index a58cd34c67..f9669fb0d9 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -21,7 +21,6 @@ from botocore import exceptions from botocore.client import BaseClient from packaging.version import Version -import tqdm from sagemaker.jumpstart import utils from sagemaker.jumpstart.curated_hub.accessors import file_generator @@ -315,9 +314,9 @@ def sync(self, model_list: List[Dict[str, str]]): max_workers=self._default_thread_pool_size, thread_name_prefix="import-models-to-curated-hub", ) as deploy_executor: - for thread_num, model in enumerate(models_to_sync): - task = deploy_executor.submit(self._sync_public_model_to_hub, model, thread_num) - tasks.append(task) + for thread_num, model in enumerate(models_to_sync): + task = deploy_executor.submit(self._sync_public_model_to_hub, model, thread_num) + tasks.append(task) # Handle failed imports results = futures.wait(tasks) @@ -365,7 +364,12 @@ def _sync_public_model_to_hub(self, model: JumpStartModelInfo, thread_num: int): ).create() if len(sync_request.files) > 0: - MultiPartCopyHandler(thread_num=thread_num, sync_request=sync_request, region=self.region, label=dest_location.key).execute() + MultiPartCopyHandler( + thread_num=thread_num, + sync_request=sync_request, + region=self.region, + label=dest_location.key, + ).execute() else: JUMPSTART_LOGGER.warning("Nothing to copy for %s v%s", model.model_id, model.version) diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py index 5f86c25a65..67d1a7a1e4 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py @@ -163,8 +163,8 @@ def test_sync_kicks_off_parallel_syncs( mock_sync_public_models.assert_has_calls( [ - mock.call(JumpStartModelInfo("mock-model-one-huggingface", "*")), - mock.call(JumpStartModelInfo("mock-model-two-pytorch", "1.0.2")), + mock.call(JumpStartModelInfo("mock-model-one-huggingface", "*"), 0), + mock.call(JumpStartModelInfo("mock-model-two-pytorch", "1.0.2"), 1), ] ) @@ -206,7 +206,7 @@ def test_sync_filters_models_that_exist_in_hub( hub.sync([model_one, model_two]) mock_sync_public_models.assert_called_once_with( - JumpStartModelInfo("mock-model-one-huggingface", "*") + JumpStartModelInfo("mock-model-one-huggingface", "*"), 0 ) @@ -252,8 +252,8 @@ def test_sync_updates_old_models_in_hub( mock_sync_public_models.assert_has_calls( [ - mock.call(JumpStartModelInfo("mock-model-one-huggingface", "*")), - mock.call(JumpStartModelInfo("mock-model-two-pytorch", "1.0.2")), + mock.call(JumpStartModelInfo("mock-model-one-huggingface", "*"), 0), + mock.call(JumpStartModelInfo("mock-model-two-pytorch", "1.0.2"), 1), ] ) @@ -299,7 +299,7 @@ def test_sync_passes_newer_hub_models( hub.sync([model_one, model_two]) mock_sync_public_models.assert_called_once_with( - JumpStartModelInfo("mock-model-one-huggingface", "*") + JumpStartModelInfo("mock-model-one-huggingface", "*"), 0 ) From 27240a9cff8f0ee567bff3de795856f7b2532d88 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Fri, 8 Mar 2024 20:55:58 +0000 Subject: [PATCH 14/15] optimize function and add hub bucket prefix --- .../curated_hub/accessors/file_generator.py | 26 ++++---- .../curated_hub/accessors/multipartcopy.py | 1 + .../jumpstart/curated_hub/curated_hub.py | 42 +++++++------ .../jumpstart/curated_hub/test_curated_hub.py | 63 +++++++++++-------- 4 files changed, 74 insertions(+), 58 deletions(-) diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py b/src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py index 3e1ac849cb..bf4c671680 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py @@ -35,16 +35,16 @@ def generate_file_infos_from_s3_location( """ parameters = {"Bucket": location.bucket, "Prefix": location.key} response = s3_client.list_objects_v2(**parameters) - contents = response.get("Contents", None) + contents = response.get("Contents") if not contents: return [] files = [] for s3_obj in contents: - key: str = s3_obj.get("Key") - size: bytes = s3_obj.get("Size", None) - last_modified: str = s3_obj.get("LastModified", None) + key = s3_obj.get("Key") + size = s3_obj.get("Size") + last_modified = s3_obj.get("LastModified") files.append(FileInfo(location.bucket, key, size, last_modified)) return files @@ -71,28 +71,26 @@ def generate_file_infos_from_model_specs( if location_type == "prefix": parameters = {"Bucket": location.bucket, "Prefix": location.key} response = s3_client.list_objects_v2(**parameters) - contents = response.get("Contents", None) + contents = response.get("Contents") for s3_obj in contents: - key: str = s3_obj.get("Key") - size: bytes = s3_obj.get("Size", None) - last_modified: datetime = s3_obj.get("LastModified", None) - dependency_type: HubContentDependencyType = dependency + key = s3_obj.get("Key") + size = s3_obj.get("Size") + last_modified = s3_obj.get("LastModified") files.append( FileInfo( location.bucket, key, size, last_modified, - dependency_type, + dependency, ) ) elif location_type == "object": parameters = {"Bucket": location.bucket, "Key": location.key} response = s3_client.head_object(**parameters) - size: bytes = response.get("ContentLength", None) - last_updated: datetime = response.get("LastModified", None) - dependency_type: HubContentDependencyType = dependency + size = response.get("ContentLength") + last_updated = response.get("LastModified") files.append( - FileInfo(location.bucket, location.key, size, last_updated, dependency_type) + FileInfo(location.bucket, location.key, size, last_updated, dependency) ) return files diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py b/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py index 3cf20ed500..52174f9eeb 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py @@ -62,6 +62,7 @@ class MultiPartCopyHandler(object): """Multi Part Copy Handler class.""" WORKERS = 20 + # Config values from in S3:Copy MULTIPART_CONFIG = 8 * (1024**2) def __init__( diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index f9669fb0d9..e6f5d0c29e 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -13,6 +13,7 @@ """This module provides the JumpStart Curated Hub class.""" from __future__ import absolute_import from concurrent import futures +from datetime import datetime import json import traceback from typing import Optional, Dict, List, Any @@ -50,6 +51,7 @@ HubContentDocument_v2, JumpStartModelInfo, S3ObjectLocation, + create_s3_object_reference_from_uri, ) @@ -73,20 +75,21 @@ def __init__( self.region = sagemaker_session.boto_region_name self._sagemaker_session = sagemaker_session self._default_thread_pool_size = 20 - self.hub_bucket_name = bucket_name or self._fetch_hub_bucket_name() self._s3_client = self._get_s3_client() + self.hub_storage_location = self._generate_hub_storage_location(bucket_name) def _get_s3_client(self) -> BaseClient: - """Returns an S3 client.""" + """Returns an S3 client used for creating a HubContentDocument.""" return boto3.client("s3", region_name=self.region) def _fetch_hub_bucket_name(self) -> str: """Retrieves hub bucket name from Hub config if exists""" try: hub_response = self._sagemaker_session.describe_hub(hub_name=self.hub_name) - hub_bucket_prefix = hub_response["S3StorageConfig"].get("S3OutputPath", None) - if hub_bucket_prefix: - return hub_bucket_prefix.replace("s3://", "") + hub_output_location = hub_response["S3StorageConfig"].get("S3OutputPath") + if hub_output_location: + location = create_s3_object_reference_from_uri(hub_output_location) + return location.bucket default_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session) JUMPSTART_LOGGER.warning( "There is not a Hub bucket associated with %s. Using %s", @@ -103,6 +106,12 @@ def _fetch_hub_bucket_name(self) -> str: ) return hub_bucket_name + def _generate_hub_storage_location(self, bucket_name: Optional[str] = None) -> None: + """Generates an ``S3ObjectLocation`` given a Hub name.""" + hub_bucket_name = bucket_name or self._fetch_hub_bucket_name() + curr_timestamp = datetime.now().timestamp() + return S3ObjectLocation(bucket=hub_bucket_name, key=f"{self.hub_name}-{curr_timestamp}") + def create( self, description: str, @@ -112,8 +121,8 @@ def create( ) -> Dict[str, str]: """Creates a hub with the given description""" - bucket_name = create_hub_bucket_if_it_does_not_exist( - self.hub_bucket_name, self._sagemaker_session + create_hub_bucket_if_it_does_not_exist( + self.hub_storage_location.bucket, self._sagemaker_session ) return self._sagemaker_session.create_hub( @@ -121,7 +130,7 @@ def create( hub_description=description, hub_display_name=display_name, hub_search_keywords=search_keywords, - s3_storage_config={"S3OutputPath": f"s3://{bucket_name}"}, + s3_storage_config={"S3OutputPath": self.hub_storage_location.get_uri()}, tags=tags, ) @@ -226,7 +235,7 @@ def _get_jumpstart_models_in_hub(self) -> List[Dict[str, Any]]: return js_models_in_hub def _determine_models_to_sync( - self, model_list: List[JumpStartModelInfo], models_in_hub + self, model_list: List[JumpStartModelInfo], models_in_hub: Dict[str, Any] ) -> List[JumpStartModelInfo]: """Determines which models from `sync` params to sync into the CuratedHub. @@ -240,14 +249,7 @@ def _determine_models_to_sync( """ models_to_sync = [] for model in model_list: - matched_model = next( - ( - hub_model - for hub_model in models_in_hub - if hub_model and hub_model["name"] == model.model_id - ), - None, - ) + matched_model = models_in_hub.get(model.model_id) # Model does not exist in Hub, sync if not matched_model: @@ -300,8 +302,9 @@ def sync(self, model_list: List[Dict[str, str]]): model_version_list.append(JumpStartModelInfo(model["model_id"], model["version"])) js_models_in_hub = self._get_jumpstart_models_in_hub() + mapped_models_in_hub = { model["name"]: model for model in js_models_in_hub } - models_to_sync = self._determine_models_to_sync(model_version_list, js_models_in_hub) + models_to_sync = self._determine_models_to_sync(model_version_list, mapped_models_in_hub) JUMPSTART_LOGGER.warning( "Syncing the following models into Hub %s: %s", self.hub_name, models_to_sync ) @@ -349,7 +352,8 @@ def _sync_public_model_to_hub(self, model: JumpStartModelInfo, thread_num: int): studio_specs = self._fetch_studio_specs(model_specs=model_specs) dest_location = S3ObjectLocation( - bucket=self.hub_bucket_name, key=f"{model.model_id}/{model.version}" + bucket=self.hub_storage_location.bucket, + key=f"{self.hub_storage_location.key}/{model.model_id}/{model.version}" ) src_files = file_generator.generate_file_infos_from_model_specs( model_specs, studio_specs, self.region, self._s3_client diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py index 67d1a7a1e4..2367d3d7ba 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py @@ -12,22 +12,25 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import from copy import deepcopy +import datetime from unittest import mock from unittest.mock import patch import pytest from mock import Mock from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub -from sagemaker.jumpstart.curated_hub.types import JumpStartModelInfo +from sagemaker.jumpstart.curated_hub.types import JumpStartModelInfo, S3ObjectLocation from sagemaker.jumpstart.types import JumpStartModelSpecs from tests.unit.sagemaker.jumpstart.constants import BASE_SPEC from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec + REGION = "us-east-1" ACCOUNT_ID = "123456789123" HUB_NAME = "mock-hub-name" MODULE_PATH = "sagemaker.jumpstart.curated_hub.curated_hub.CuratedHub" +FAKE_TIME = datetime.datetime(1997, 8, 14, 00, 00, 00) @pytest.fixture() def sagemaker_session(): @@ -39,7 +42,7 @@ def sagemaker_session(): "Boto3/1.9.69 Python/3.6.5 Linux/4.14.77-70.82.amzn1.x86_64 Botocore/1.12.69 Resource" ) sagemaker_session_mock.describe_hub.return_value = { - "S3StorageConfig": {"S3OutputPath": "mock-bucket-123"} + "S3StorageConfig": {"S3OutputPath": "s3://mock-bucket-123"} } sagemaker_session_mock.account_id.return_value = ACCOUNT_ID return sagemaker_session_mock @@ -66,7 +69,9 @@ def test_instantiates(sagemaker_session): ), ], ) +@patch("sagemaker.jumpstart.curated_hub.curated_hub.CuratedHub._generate_hub_storage_location") def test_create_with_no_bucket_name( + mock_generate_hub_storage_location, sagemaker_session, hub_name, hub_description, @@ -75,10 +80,12 @@ def test_create_with_no_bucket_name( hub_search_keywords, tags, ): + storage_location = S3ObjectLocation("sagemaker-hubs-us-east-1-123456789123",f"{hub_name}-{FAKE_TIME.timestamp()}") + mock_generate_hub_storage_location.return_value = storage_location create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} sagemaker_session.create_hub = Mock(return_value=create_hub) sagemaker_session.describe_hub.return_value = { - "S3StorageConfig": {"S3OutputPath": hub_bucket_name} + "S3StorageConfig": {"S3OutputPath": f"s3://{hub_bucket_name}/{storage_location.key}"} } hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session) request = { @@ -86,7 +93,9 @@ def test_create_with_no_bucket_name( "hub_description": hub_description, "hub_display_name": hub_display_name, "hub_search_keywords": hub_search_keywords, - "s3_storage_config": {"S3OutputPath": "s3://sagemaker-hubs-us-east-1-123456789123"}, + "s3_storage_config": { + "S3OutputPath": f"s3://sagemaker-hubs-us-east-1-123456789123/{storage_location.key}" + }, "tags": tags, } response = hub.create( @@ -113,7 +122,9 @@ def test_create_with_no_bucket_name( ), ], ) +@patch("sagemaker.jumpstart.curated_hub.curated_hub.CuratedHub._generate_hub_storage_location") def test_create_with_bucket_name( + mock_generate_hub_storage_location, sagemaker_session, hub_name, hub_description, @@ -122,6 +133,8 @@ def test_create_with_bucket_name( hub_search_keywords, tags, ): + storage_location = S3ObjectLocation(hub_bucket_name,f"{hub_name}-{FAKE_TIME.timestamp()}") + mock_generate_hub_storage_location.return_value = storage_location create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} sagemaker_session.create_hub = Mock(return_value=create_hub) hub = CuratedHub( @@ -132,7 +145,7 @@ def test_create_with_bucket_name( "hub_description": hub_description, "hub_display_name": hub_display_name, "hub_search_keywords": hub_search_keywords, - "s3_storage_config": {"S3OutputPath": "s3://mock-bucket-123"}, + "s3_storage_config": {"S3OutputPath": f"s3://mock-bucket-123/{storage_location.key}"}, "tags": tags, } response = hub.create( @@ -397,8 +410,8 @@ def test_determine_models_to_sync(sagemaker_session): hub_name = "mock_hub_name" hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session) - js_models_in_hub = [ - { + js_model_map = { + "mock-model-two-pytorch": { "name": "mock-model-two-pytorch", "version": "1.0.1", "search_keywords": [ @@ -406,7 +419,7 @@ def test_determine_models_to_sync(sagemaker_session): "@jumpstart-model-version:1.0.2", ], }, - { + "mock-model-four-huggingface": { "name": "mock-model-four-huggingface", "version": "2.0.2", "search_keywords": [ @@ -414,15 +427,15 @@ def test_determine_models_to_sync(sagemaker_session): "@jumpstart-model-version:2.0.2", ], }, - ] + } model_one = JumpStartModelInfo("mock-model-one-huggingface", "1.2.3") model_two = JumpStartModelInfo("mock-model-two-pytorch", "1.0.2") # No model_one, older model_two - res = hub._determine_models_to_sync([model_one, model_two], js_models_in_hub) + res = hub._determine_models_to_sync([model_one, model_two], js_model_map) assert res == [model_one, model_two] - js_models_in_hub = [ - { + js_model_map = { + "mock-model-two-pytorch": { "name": "mock-model-two-pytorch", "version": "1.0.3", "search_keywords": [ @@ -430,7 +443,7 @@ def test_determine_models_to_sync(sagemaker_session): "@jumpstart-model-version:1.0.3", ], }, - { + "mock-model-four-huggingface": { "name": "mock-model-four-huggingface", "version": "2.0.2", "search_keywords": [ @@ -438,13 +451,13 @@ def test_determine_models_to_sync(sagemaker_session): "@jumpstart-model-version:2.0.2", ], }, - ] + } # No model_one, newer model_two - res = hub._determine_models_to_sync([model_one, model_two], js_models_in_hub) + res = hub._determine_models_to_sync([model_one, model_two], js_model_map) assert res == [model_one] - js_models_in_hub = [ - { + js_model_map = { + "mock-model-one-huggingface": { "name": "mock-model-one-huggingface", "version": "1.2.3", "search_keywords": [ @@ -452,7 +465,7 @@ def test_determine_models_to_sync(sagemaker_session): "@jumpstart-model-version:1.2.3", ], }, - { + "mock-model-two-pytorch": { "name": "mock-model-two-pytorch", "version": "1.0.2", "search_keywords": [ @@ -460,13 +473,13 @@ def test_determine_models_to_sync(sagemaker_session): "@jumpstart-model-version:1.0.2", ], }, - ] + } # Same model_one, same model_two - res = hub._determine_models_to_sync([model_one, model_two], js_models_in_hub) + res = hub._determine_models_to_sync([model_one, model_two], js_model_map) assert res == [] - js_models_in_hub = [ - { + js_model_map = { + "mock-model-one-huggingface": { "name": "mock-model-one-huggingface", "version": "1.2.1", "search_keywords": [ @@ -474,7 +487,7 @@ def test_determine_models_to_sync(sagemaker_session): "@jumpstart-model-version:1.2.1", ], }, - { + "mock-model-two-pytorch": { "name": "mock-model-two-pytorch", "version": "1.0.2", "search_keywords": [ @@ -482,7 +495,7 @@ def test_determine_models_to_sync(sagemaker_session): "@jumpstart-model-version:1.0.2", ], }, - ] + } # Old model_one, same model_two - res = hub._determine_models_to_sync([model_one, model_two], js_models_in_hub) + res = hub._determine_models_to_sync([model_one, model_two], js_model_map) assert res == [model_one] From ce73f620f278cb44a3e58482646a60ece5a23fcf Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Fri, 8 Mar 2024 21:09:33 +0000 Subject: [PATCH 15/15] docstrings and linting --- .../curated_hub/accessors/file_generator.py | 5 +---- .../curated_hub/accessors/public_model_data.py | 2 +- .../jumpstart/curated_hub/curated_hub.py | 10 +++++----- .../jumpstart/curated_hub/sync/request.py | 18 +++++++++--------- src/sagemaker/jumpstart/curated_hub/types.py | 11 ----------- src/sagemaker/jumpstart/curated_hub/utils.py | 12 ++++++++++++ .../jumpstart/curated_hub/test_curated_hub.py | 7 +++++-- 7 files changed, 33 insertions(+), 32 deletions(-) diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py b/src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py index bf4c671680..0393b4234a 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py @@ -14,7 +14,6 @@ from __future__ import absolute_import from typing import Any, Dict, List -from datetime import datetime from botocore.client import BaseClient from sagemaker.jumpstart.curated_hub.types import ( @@ -90,7 +89,5 @@ def generate_file_infos_from_model_specs( response = s3_client.head_object(**parameters) size = response.get("ContentLength") last_updated = response.get("LastModified") - files.append( - FileInfo(location.bucket, location.key, size, last_updated, dependency) - ) + files.append(FileInfo(location.bucket, location.key, size, last_updated, dependency)) return files diff --git a/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py index 1f9d8683af..89e3a2f108 100644 --- a/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py +++ b/src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py @@ -17,8 +17,8 @@ from sagemaker.jumpstart.curated_hub.types import ( HubContentDependencyType, S3ObjectLocation, - create_s3_object_reference_from_uri, ) +from sagemaker.jumpstart.curated_hub.utils import create_s3_object_reference_from_uri from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.types import JumpStartModelSpecs from sagemaker.jumpstart.utils import get_jumpstart_content_bucket diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index e6f5d0c29e..a35948f138 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -46,12 +46,12 @@ from sagemaker.jumpstart.curated_hub.utils import ( create_hub_bucket_if_it_does_not_exist, generate_default_hub_bucket_name, + create_s3_object_reference_from_uri, ) from sagemaker.jumpstart.curated_hub.types import ( HubContentDocument_v2, JumpStartModelInfo, S3ObjectLocation, - create_s3_object_reference_from_uri, ) @@ -302,7 +302,7 @@ def sync(self, model_list: List[Dict[str, str]]): model_version_list.append(JumpStartModelInfo(model["model_id"], model["version"])) js_models_in_hub = self._get_jumpstart_models_in_hub() - mapped_models_in_hub = { model["name"]: model for model in js_models_in_hub } + mapped_models_in_hub = {model["name"]: model for model in js_models_in_hub} models_to_sync = self._determine_models_to_sync(model_version_list, mapped_models_in_hub) JUMPSTART_LOGGER.warning( @@ -316,9 +316,9 @@ def sync(self, model_list: List[Dict[str, str]]): with futures.ThreadPoolExecutor( max_workers=self._default_thread_pool_size, thread_name_prefix="import-models-to-curated-hub", - ) as deploy_executor: + ) as import_executor: for thread_num, model in enumerate(models_to_sync): - task = deploy_executor.submit(self._sync_public_model_to_hub, model, thread_num) + task = import_executor.submit(self._sync_public_model_to_hub, model, thread_num) tasks.append(task) # Handle failed imports @@ -353,7 +353,7 @@ def _sync_public_model_to_hub(self, model: JumpStartModelInfo, thread_num: int): dest_location = S3ObjectLocation( bucket=self.hub_storage_location.bucket, - key=f"{self.hub_storage_location.key}/{model.model_id}/{model.version}" + key=f"{self.hub_storage_location.key}/curated_models/{model.model_id}/{model.version}", ) src_files = file_generator.generate_file_infos_from_model_specs( model_specs, studio_specs, self.region, self._s3_client diff --git a/src/sagemaker/jumpstart/curated_hub/sync/request.py b/src/sagemaker/jumpstart/curated_hub/sync/request.py index 11e18eddd2..0e620432ce 100644 --- a/src/sagemaker/jumpstart/curated_hub/sync/request.py +++ b/src/sagemaker/jumpstart/curated_hub/sync/request.py @@ -35,16 +35,21 @@ def __init__( ): """Contains information required to sync data into a Hub. - Returns: - :var: files (List[FileInfo]): Files that shoudl be synced. - :var: destination (S3ObjectLocation): Location to which to sync the files. + Attrs: + files (List[FileInfo]): Files that should be synced. + destination (S3ObjectLocation): Location to which to sync the files. """ self.files = list(files_to_copy) self.destination = destination class HubSyncRequestFactory: - """Generates a ``HubSyncRequest`` which is required to sync data into a Hub.""" + """Generates a ``HubSyncRequest`` which is required to sync data into a Hub. + + Creates a ``HubSyncRequest`` class containing: + :var: files (List[FileInfo]): Files that should be synced. + :var: destination (S3ObjectLocation): Location to which to sync the files. + """ def __init__( self, @@ -59,11 +64,6 @@ def __init__( src_files (List[FileInfo]): List of files to sync to destination bucket dest_files (List[FileInfo]): List of files already in destination bucket destination (S3ObjectLocation): S3 destination for copied data - - Returns: - ``HubSyncRequest`` class containing: - :var: files (List[FileInfo]): Files that shoudl be synced. - :var: destination (S3ObjectLocation): Location to which to sync the files. """ self.comparator = comparator self.destination = destination diff --git a/src/sagemaker/jumpstart/curated_hub/types.py b/src/sagemaker/jumpstart/curated_hub/types.py index 28c5f39300..99f9cfdc63 100644 --- a/src/sagemaker/jumpstart/curated_hub/types.py +++ b/src/sagemaker/jumpstart/curated_hub/types.py @@ -18,7 +18,6 @@ from datetime import datetime from sagemaker.jumpstart.types import JumpStartDataHolderType, JumpStartModelSpecs -from sagemaker.s3_utils import parse_s3_url @dataclass @@ -40,16 +39,6 @@ def get_uri(self) -> str: return f"s3://{self.bucket}/{self.key}" -def create_s3_object_reference_from_uri(s3_uri: str) -> S3ObjectLocation: - """Utiity to help generate an S3 object reference""" - bucket, key = parse_s3_url(s3_uri) - - return S3ObjectLocation( - bucket=bucket, - key=key, - ) - - @dataclass class JumpStartModelInfo: """Helper class for storing JumpStart model info.""" diff --git a/src/sagemaker/jumpstart/curated_hub/utils.py b/src/sagemaker/jumpstart/curated_hub/utils.py index ac01da45ca..b116411801 100644 --- a/src/sagemaker/jumpstart/curated_hub/utils.py +++ b/src/sagemaker/jumpstart/curated_hub/utils.py @@ -14,6 +14,8 @@ from __future__ import absolute_import import re from typing import Optional +from sagemaker.jumpstart.curated_hub.types import S3ObjectLocation +from sagemaker.s3_utils import parse_s3_url from sagemaker.session import Session from sagemaker.utils import aws_partition from sagemaker.jumpstart.types import ( @@ -131,6 +133,16 @@ def generate_default_hub_bucket_name( return f"sagemaker-hubs-{region}-{account_id}" +def create_s3_object_reference_from_uri(s3_uri: str) -> S3ObjectLocation: + """Utiity to help generate an S3 object reference""" + bucket, key = parse_s3_url(s3_uri) + + return S3ObjectLocation( + bucket=bucket, + key=key, + ) + + def create_hub_bucket_if_it_does_not_exist( bucket_name: Optional[str] = None, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py index 2367d3d7ba..16a5588094 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py @@ -32,6 +32,7 @@ FAKE_TIME = datetime.datetime(1997, 8, 14, 00, 00, 00) + @pytest.fixture() def sagemaker_session(): boto_mock = Mock(name="boto_session") @@ -80,7 +81,9 @@ def test_create_with_no_bucket_name( hub_search_keywords, tags, ): - storage_location = S3ObjectLocation("sagemaker-hubs-us-east-1-123456789123",f"{hub_name}-{FAKE_TIME.timestamp()}") + storage_location = S3ObjectLocation( + "sagemaker-hubs-us-east-1-123456789123", f"{hub_name}-{FAKE_TIME.timestamp()}" + ) mock_generate_hub_storage_location.return_value = storage_location create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} sagemaker_session.create_hub = Mock(return_value=create_hub) @@ -133,7 +136,7 @@ def test_create_with_bucket_name( hub_search_keywords, tags, ): - storage_location = S3ObjectLocation(hub_bucket_name,f"{hub_name}-{FAKE_TIME.timestamp()}") + storage_location = S3ObjectLocation(hub_bucket_name, f"{hub_name}-{FAKE_TIME.timestamp()}") mock_generate_hub_storage_location.return_value = storage_location create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} sagemaker_session.create_hub = Mock(return_value=create_hub)