MultiPartCopy with Sync Algorithm #4475

Merged: 15 commits, Mar 12, 2024
11 changes: 8 additions & 3 deletions src/sagemaker/jumpstart/cache.py
@@ -101,7 +101,8 @@ def __init__(
Default: None (no config).
s3_client (Optional[boto3.client]): s3 client to use. Default: None.
sagemaker_session (Optional[sagemaker.session.Session]): A SageMaker Session object,
used for SageMaker interactions. Default: Session in region associated with boto3 session.
used for SageMaker interactions. Default: Session in region associated with boto3
session.
"""

self._region = region
@@ -358,7 +359,9 @@ def _retrieval_function(
hub_content_type=data_type
)

model_specs = JumpStartModelSpecs(DescribeHubContentsResponse(hub_model_description), is_hub_content=True)
model_specs = JumpStartModelSpecs(
DescribeHubContentsResponse(hub_model_description), is_hub_content=True
)

utils.emit_logs_based_on_model_specs(
model_specs,
@@ -372,7 +375,9 @@ def _retrieval_function(
hub_name, _, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info)
response: Dict[str, Any] = self._sagemaker_session.describe_hub(hub_name=hub_name)
hub_description = DescribeHubResponse(response)
return JumpStartCachedContentValue(formatted_content=DescribeHubResponse(hub_description))
return JumpStartCachedContentValue(
formatted_content=DescribeHubResponse(hub_description)
)
raise ValueError(
f"Bad value for key '{key}': must be in ",
f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubType.HUB, HubContentType.MODEL]}"
Empty file.
95 changes: 95 additions & 0 deletions src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py
@@ -0,0 +1,95 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains important utilities related to HubContent data files."""
from __future__ import absolute_import
from typing import Any, Dict, List

from datetime import datetime
from botocore.client import BaseClient

from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo, HubContentDependencyType
from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation
from sagemaker.jumpstart.curated_hub.accessors.public_model_data import PublicModelDataAccessor
from sagemaker.jumpstart.types import JumpStartModelSpecs


def generate_file_infos_from_s3_location(
location: S3ObjectLocation, s3_client: BaseClient
) -> List[FileInfo]:
"""Lists objects from an S3 bucket and formats into FileInfo.

Returns a list of ``FileInfo`` objects from the specified bucket location.
"""
parameters = {"Bucket": location.bucket, "Prefix": location.key}
response = s3_client.list_objects_v2(**parameters)
contents = response.get("Contents", None)

if not contents:
return []

files = []
for s3_obj in contents:
key: str = s3_obj.get("Key")
size: int = s3_obj.get("Size", None)
last_modified: datetime = s3_obj.get("LastModified", None)
files.append(FileInfo(location.bucket, key, size, last_modified))
return files


def generate_file_infos_from_model_specs(
model_specs: JumpStartModelSpecs,
studio_specs: Dict[str, Any],
region: str,
s3_client: BaseClient,
) -> List[FileInfo]:
"""Collects data locations from JumpStart public model specs and converts into `FileInfo`.

Returns a list of `FileInfo` objects from dependencies found in the public
model specs.
"""
public_model_data_accessor = PublicModelDataAccessor(
region=region, model_specs=model_specs, studio_specs=studio_specs
)
files = []
for dependency in HubContentDependencyType:
location: S3ObjectLocation = public_model_data_accessor.get_s3_reference(dependency)

# Prefix
if location.key.endswith("/"):
parameters = {"Bucket": location.bucket, "Prefix": location.key}
response = s3_client.list_objects_v2(**parameters)
contents = response.get("Contents", None)
for s3_obj in contents:
key: str = s3_obj.get("Key")
size: int = s3_obj.get("Size", None)
last_modified: datetime = s3_obj.get("LastModified", None)
dependency_type: HubContentDependencyType = dependency
files.append(
FileInfo(
location.bucket,
key,
size,
last_modified,
dependency_type,
)
)
else:
parameters = {"Bucket": location.bucket, "Key": location.key}
response = s3_client.head_object(**parameters)
size: int = response.get("ContentLength", None)
last_updated: datetime = response.get("LastModified", None)
dependency_type: HubContentDependencyType = dependency
files.append(
FileInfo(location.bucket, location.key, size, last_updated, dependency_type)
)
return files
53 changes: 53 additions & 0 deletions src/sagemaker/jumpstart/curated_hub/accessors/fileinfo.py
@@ -0,0 +1,53 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains important details related to HubContent data files."""
from __future__ import absolute_import

from enum import Enum
from dataclasses import dataclass
from typing import Optional
from datetime import datetime

from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation


class HubContentDependencyType(str, Enum):
"""Enum class for HubContent dependency names"""

INFERENCE_ARTIFACT = "inference_artifact_s3_reference"
TRAINING_ARTIFACT = "training_artifact_s3_reference"
INFERENCE_SCRIPT = "inference_script_s3_reference"
TRAINING_SCRIPT = "training_script_s3_reference"
DEFAULT_TRAINING_DATASET = "default_training_dataset_s3_reference"
DEMO_NOTEBOOK = "demo_notebook_s3_reference"
MARKDOWN = "markdown_s3_reference"


@dataclass
class FileInfo:
"""Data class for additional S3 file info."""

location: S3ObjectLocation

def __init__(
self,
bucket: str,
key: str,
size: Optional[int],
last_updated: Optional[datetime],
dependecy_type: Optional[HubContentDependencyType] = None,
):
self.location = S3ObjectLocation(bucket, key)
self.size = size
self.last_updated = last_updated
self.dependecy_type = dependecy_type
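For illustration, a small sketch constructing a ``FileInfo`` by hand; the bucket, key, size, and timestamp are made-up values, and the keyword follows the ``dependecy_type`` spelling used in the signature above.

from datetime import datetime, timezone

from sagemaker.jumpstart.curated_hub.accessors.fileinfo import (
    FileInfo,
    HubContentDependencyType,
)

# Made-up values for a hypothetical inference artifact.
file = FileInfo(
    bucket="my-example-bucket",
    key="model/infer.tar.gz",
    size=8 * 1024 * 1024,
    last_updated=datetime(2024, 3, 1, tzinfo=timezone.utc),
    dependecy_type=HubContentDependencyType.INFERENCE_ARTIFACT,
)
print(file.location.get_uri())  # s3://my-example-bucket/model/infer.tar.gz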
128 changes: 128 additions & 0 deletions src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py
@@ -0,0 +1,128 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module provides a class that perfrms functionalities similar to ``S3:Copy``."""
from __future__ import absolute_import
from typing import List

import boto3
import botocore
import tqdm

from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo
from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation

s3transfer = boto3.s3.transfer


# pylint: disable=R1705,R1710
def human_readable_size(value: int) -> str:
"""Convert a size in bytes into a human readable format.

For example::

>>> human_readable_size(1)
'1 Byte'
>>> human_readable_size(10)
'10 Bytes'
>>> human_readable_size(1024)
'1.0 KiB'
>>> human_readable_size(1024 * 1024)
'1.0 MiB'

:param value: The size in bytes.
:return: The size in a human readable format based on base-2 units.

"""
base = 1024
bytes_int = float(value)

if bytes_int == 1:
return "1 Byte"
elif bytes_int < base:
return "%d Bytes" % bytes_int

for i, suffix in enumerate(("KiB", "MiB", "GiB", "TiB", "PiB", "EiB")):
unit = base ** (i + 2)
if round((bytes_int / unit) * base) < base:
return "%.1f %s" % ((base * bytes_int / unit), suffix)


class MultiPartCopyHandler(object):
"""Multi Part Copy Handler class."""

WORKERS = 20
MULTIPART_CONFIG = 8 * (1024**2)

def __init__(
self,
region: str,
files: List[FileInfo],
dest_location: S3ObjectLocation,
):
"""Something."""
self.region = region
self.files = files
self.dest_location = dest_location

config = botocore.config.Config(max_pool_connections=self.WORKERS)
self.s3_client = boto3.client("s3", region_name=self.region, config=config)
transfer_config = s3transfer.TransferConfig(
multipart_threshold=self.MULTIPART_CONFIG,
multipart_chunksize=self.MULTIPART_CONFIG,
max_bandwidth=True,
use_threads=True,
max_concurrency=self.WORKERS,
)
self.transfer_manager = s3transfer.create_transfer_manager(
client=self.s3_client, config=transfer_config
)

def _copy_file(self, file: FileInfo, update_fn):
"""Something."""
copy_source = {"Bucket": file.location.bucket, "Key": file.location.key}
result = self.transfer_manager.copy(
bucket=self.dest_location.bucket,
key=f"{self.dest_location.key}/{file.location.key}",
copy_source=copy_source,
subscribers=[
s3transfer.ProgressCallbackInvoker(update_fn),
],
)
result.result()

def execute(self):
"""Something."""
total_size = sum([file.size for file in self.files])
JUMPSTART_LOGGER.warning(
[Reviewer comment (Member)] I don't think this should be a warning. Can we modify JUMPSTART_LOGGER so that if you do JUMPSTART_LOGGER.info("blah", stdout=True), then it gets printed? Warning gives a negative connotation.
"Copying %s files (%s) into %s/%s",
len(self.files),
human_readable_size(total_size),
self.dest_location.bucket,
self.dest_location.key,
)

progress = tqdm.tqdm(
desc="JumpStart Sync",
total=total_size,
unit="B",
unit_scale=1,
position=0,
bar_format="{desc:<10}{percentage:3.0f}%|{bar:10}{r_bar}",
)

for file in self.files:
self._copy_file(file, progress.update)

self.transfer_manager.shutdown()
progress.close()
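To show how the pieces fit together, a minimal end-to-end sketch of the handler above; the region and the source and destination locations are hypothetical.

import boto3

from sagemaker.jumpstart.curated_hub.accessors.filegenerator import (
    generate_file_infos_from_s3_location,
)
from sagemaker.jumpstart.curated_hub.accessors.multipartcopy import MultiPartCopyHandler
from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation

region = "us-west-2"  # hypothetical region
src = S3ObjectLocation(bucket="jumpstart-source-bucket", key="model-artifacts/")
dest = S3ObjectLocation(bucket="my-curated-hub-bucket", key="imported-models")

files = generate_file_infos_from_s3_location(src, boto3.client("s3", region_name=region))

# Each object is copied to dest.bucket under "<dest.key>/<original key>" using
# multi-part copies and a shared pool of worker threads.
MultiPartCopyHandler(region=region, files=files, dest_location=dest).execute()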
47 changes: 47 additions & 0 deletions src/sagemaker/jumpstart/curated_hub/accessors/objectlocation.py
@@ -0,0 +1,47 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module utilites to assist S3 client calls for the Curated Hub."""
from __future__ import absolute_import
from dataclasses import dataclass
from typing import Dict

from sagemaker.s3_utils import parse_s3_url


@dataclass
class S3ObjectLocation:
"""Helper class for S3 object references"""

bucket: str
key: str

def format_for_s3_copy(self) -> Dict[str, str]:
"""Returns a dict formatted for S3 copy calls"""
return {
"Bucket": self.bucket,
"Key": self.key,
}

def get_uri(self) -> str:
"""Returns the s3 URI"""
return f"s3://{self.bucket}/{self.key}"


def create_s3_object_reference_from_uri(s3_uri: str) -> S3ObjectLocation:
"""Utiity to help generate an S3 object reference"""
bucket, key = parse_s3_url(s3_uri)

return S3ObjectLocation(
bucket=bucket,
key=key,
)
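Finally, a short sketch of the helpers above; the URI is a made-up example.

from sagemaker.jumpstart.curated_hub.accessors.objectlocation import (
    create_s3_object_reference_from_uri,
)

# Hypothetical URI; parse_s3_url splits it into a bucket and key.
location = create_s3_object_reference_from_uri("s3://my-example-bucket/path/to/model.tar.gz")
print(location.bucket)                # my-example-bucket
print(location.key)                   # path/to/model.tar.gz
print(location.format_for_s3_copy())  # {'Bucket': 'my-example-bucket', 'Key': 'path/to/model.tar.gz'}
print(location.get_uri())             # s3://my-example-bucket/path/to/model.tar.gz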