
MultiPartCopy with Sync Algorithm #4475


Merged: 15 commits, Mar 12, 2024
106 changes: 52 additions & 54 deletions src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py
@@ -12,9 +12,9 @@
# language governing permissions and limitations under the License.
"""This module contains important utilities related to HubContent data files."""
from __future__ import absolute_import
from functools import singledispatchmethod
from typing import Any, Dict, List, Optional

from datetime import datetime
from botocore.client import BaseClient

from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo, HubContentDependencyType
@@ -33,26 +33,19 @@ def __init__(
self.s3_client = s3_client
self.studio_specs = studio_specs

@singledispatchmethod
def format(self, file_input) -> List[FileInfo]:
Review comment (Member): do we need this?

"""Implement."""
# pylint: disable=W0107
pass

@format.register
def _(self, file_input: S3ObjectLocation) -> List[FileInfo]:
"""Something."""
files = self.s3_format(file_input)
return files
"""Dispatch method that is implemented in below registered functions."""
raise NotImplementedError

@format.register
def _(self, file_input: JumpStartModelSpecs) -> List[FileInfo]:
"""Something."""
files = self.specs_format(file_input, self.studio_specs)
return files

def s3_format(self, file_input: S3ObjectLocation) -> List[FileInfo]:
"""Retrieves data from a bucket and formats into FileInfo"""
class S3PathFileGenerator(FileGenerator):
"""Utility class to help format all objects in an S3 bucket."""

def format(self, file_input: S3ObjectLocation) -> List[FileInfo]:
"""Retrieves data from an S3 bucket and formats into FileInfo.

Returns a list of ``FileInfo`` objects from the specified bucket location.
"""
parameters = {"Bucket": file_input.bucket, "Prefix": file_input.key}
response = self.s3_client.list_objects_v2(**parameters)
contents = response.get("Contents", None)
@@ -66,47 +59,52 @@ def s3_format(self, file_input: S3ObjectLocation) -> List[FileInfo]:
key: str = s3_obj.get("Key")
size: int = s3_obj.get("Size", None)
last_modified: datetime = s3_obj.get("LastModified", None)
files.append(FileInfo(key, size, last_modified))
files.append(FileInfo(file_input.bucket, key, size, last_modified))
return files

def specs_format(
self, file_input: JumpStartModelSpecs, studio_specs: Dict[str, Any]
) -> List[FileInfo]:
"""Collects data locations from JumpStart public model specs and
converts into FileInfo.

class ModelSpecsFileGenerator(FileGenerator):
"""Utility class to help format all data paths from JumpStart public model specs."""

def format(self, file_input: JumpStartModelSpecs) -> List[FileInfo]:
"""Collects data locations from JumpStart public model specs and converts into FileInfo`.

Returns a list of ``FileInfo`` objects from dependencies found in the public
model specs.
"""
public_model_data_accessor = PublicModelDataAccessor(
region=self.region, model_specs=file_input, studio_specs=studio_specs
region=self.region, model_specs=file_input, studio_specs=self.studio_specs
)
function_table = {
HubContentDependencyType.INFERENCE_ARTIFACT: (
public_model_data_accessor.get_inference_artifact_s3_reference
),
HubContentDependencyType.TRAINING_ARTIFACT: (
public_model_data_accessor.get_training_artifact_s3_reference
),
HubContentDependencyType.INFERNECE_SCRIPT: (
public_model_data_accessor.get_inference_script_s3_reference
),
HubContentDependencyType.TRAINING_SCRIPT: (
public_model_data_accessor.get_training_script_s3_reference
),
HubContentDependencyType.DEFAULT_TRAINING_DATASET: (
public_model_data_accessor.get_default_training_dataset_s3_reference
),
HubContentDependencyType.DEMO_NOTEBOOK: (
public_model_data_accessor.get_demo_notebook_s3_reference
),
HubContentDependencyType.MARKDOWN: public_model_data_accessor.get_markdown_s3_reference,
}
files = []
for dependency in HubContentDependencyType:
location = function_table[dependency]()
parameters = {"Bucket": location.bucket, "Prefix": location.key}
response = self.s3_client.head_object(**parameters)
key: str = location.key
size: bytes = response.get("ContentLength", None)
last_updated: str = response.get("LastModified", None)
dependency_type: HubContentDependencyType = dependency
files.append(FileInfo(key, size, last_updated, dependency_type))
location: S3ObjectLocation = public_model_data_accessor.get_s3_reference(dependency)

# A key ending in "/" is a prefix: list and record every object under it
if location.key.endswith("/"):
parameters = {"Bucket": location.bucket, "Prefix": location.key}
response = self.s3_client.list_objects_v2(**parameters)
contents = response.get("Contents", None)
for s3_obj in contents:
key: str = s3_obj.get("Key")
Review comment (Contributor): isn't the typing Optional[str]?

size: int = s3_obj.get("Size", None)
last_modified: datetime = s3_obj.get("LastModified", None)
dependency_type: HubContentDependencyType = dependency
files.append(
Review comment (Member): too much indentation here, can we break it up into helper functions?

FileInfo(
location.bucket,
key,
size,
last_modified,
dependency_type,
)
)
else:
parameters = {"Bucket": location.bucket, "Key": location.key}
response = self.s3_client.head_object(**parameters)
size: int = response.get("ContentLength", None)
last_updated: datetime = response.get("LastModified", None)
dependency_type: HubContentDependencyType = dependency
files.append(
FileInfo(location.bucket, location.key, size, last_updated, dependency_type)
)
return files
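
For context, here is a rough sketch of how the refactored generators might be driven after this change; the region, bucket, and client wiring below are illustrative assumptions, not code from the PR (the constructor signature is inferred from the fields assigned in ``__init__``):

    import boto3

    from sagemaker.jumpstart.curated_hub.accessors.filegenerator import S3PathFileGenerator
    from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation

    region = "us-west-2"
    s3_client = boto3.client("s3", region_name=region)

    # List every object under a source prefix as FileInfo records.
    src_location = S3ObjectLocation(bucket="example-bucket", key="model-artifacts/")
    src_files = S3PathFileGenerator(region, s3_client).format(src_location)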
26 changes: 16 additions & 10 deletions src/sagemaker/jumpstart/curated_hub/accessors/fileinfo.py
@@ -16,32 +16,38 @@
from enum import Enum
from dataclasses import dataclass
from typing import Optional
from datetime import datetime

from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation


class HubContentDependencyType(str, Enum):
"""Enum class for HubContent dependency names"""

INFERENCE_ARTIFACT = "INFERENCE_ARTIFACT"
TRAINING_ARTIFACT = "TRAINING_ARTIFACT"
INFERNECE_SCRIPT = "INFERENCE_SCRIPT"
TRAINING_SCRIPT = "TRAINING_SCRIPT"
DEFAULT_TRAINING_DATASET = "DEFAULT_TRAINING_DATASET"
DEMO_NOTEBOOK = "DEMO_NOTEBOOK"
MARKDOWN = "MARKDOWN"
INFERENCE_ARTIFACT = "inference_artifact_s3_reference"
TRAINING_ARTIFACT = "training_artifact_s3_reference"
INFERENCE_SCRIPT = "inference_script_s3_reference"
TRAINING_SCRIPT = "training_script_s3_reference"
DEFAULT_TRAINING_DATASET = "default_training_dataset_s3_reference"
DEMO_NOTEBOOK = "demo_notebook_s3_reference"
MARKDOWN = "markdown_s3_reference"


@dataclass
class FileInfo:
"""Data class for additional S3 file info."""

location: S3ObjectLocation

def __init__(
self,
name: str,
bucket: str,
key: str,
size: Optional[int],
last_updated: Optional[str],
last_updated: Optional[datetime],
dependency_type: Optional[HubContentDependencyType] = None,
):
self.name = name
self.location = S3ObjectLocation(bucket, key)
self.size = size
self.last_updated = last_updated
self.dependency_type = dependency_type
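
The enum values now double as attribute names, which appears to be what lets the new ``get_s3_reference`` call replace the old per-dependency function table. A minimal sketch of that pattern, using a stand-in accessor class (not the PR's ``PublicModelDataAccessor``):

    from enum import Enum


    class HubContentDependencyType(str, Enum):
        """Abbreviated copy of the enum above."""

        INFERENCE_ARTIFACT = "inference_artifact_s3_reference"
        TRAINING_ARTIFACT = "training_artifact_s3_reference"


    class FakeModelDataAccessor:
        """Stand-in accessor exposing one attribute per dependency type."""

        inference_artifact_s3_reference = "s3://example-bucket/inference/artifact.tar.gz"
        training_artifact_s3_reference = "s3://example-bucket/training/artifact.tar.gz"

        def get_s3_reference(self, dependency_type: HubContentDependencyType) -> str:
            # The enum's string value is the attribute name, so one getattr
            # call replaces the seven-entry function table removed above.
            return getattr(self, dependency_type.value)


    accessor = FakeModelDataAccessor()
    print(accessor.get_s3_reference(HubContentDependencyType.INFERENCE_ARTIFACT))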
128 changes: 128 additions & 0 deletions src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py
@@ -0,0 +1,128 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module provides a class that perfrms functionalities similar to ``S3:Copy``."""
from __future__ import absolute_import
from typing import List

import boto3
import botocore
import tqdm

from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo
from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation

s3transfer = boto3.s3.transfer


# pylint: disable=R1705,R1710
def human_readable_size(value):
"""Convert a size in bytes into a human readable format.

For example::

>>> human_readable_size(1)
'1 Byte'
>>> human_readable_size(10)
'10 Bytes'
>>> human_readable_size(1024)
'1.0 KiB'
>>> human_readable_size(1024 * 1024)
'1.0 MiB'

:param value: The size in bytes.
:return: The size in a human readable format based on base-2 units.

"""
base = 1024
bytes_int = float(value)

if bytes_int == 1:
return "1 Byte"
elif bytes_int < base:
return "%d Bytes" % bytes_int

for i, suffix in enumerate(("KiB", "MiB", "GiB", "TiB", "PiB", "EiB")):
unit = base ** (i + 2)
if round((bytes_int / unit) * base) < base:
return "%.1f %s" % ((base * bytes_int / unit), suffix)


class MultiPartCopyHandler(object):
"""Multi Part Copy Handler class."""

WORKERS = 20
MULTIPART_CONFIG = 8 * (1024**2)

def __init__(
self,
region: str,
files: List[FileInfo],
dest_location: S3ObjectLocation,
):
"""Something."""
self.region = region
self.files = files
self.dest_location = dest_location

config = botocore.config.Config(max_pool_connections=self.WORKERS)
self.s3_client = boto3.client("s3", region_name=self.region, config=config)
transfer_config = s3transfer.TransferConfig(
multipart_threshold=self.MULTIPART_CONFIG,
multipart_chunksize=self.MULTIPART_CONFIG,
max_bandwidth=None,  # unthrottled; a bare True here would be read as a 1 B/s cap
use_threads=True,
max_concurrency=self.WORKERS,
)
self.transfer_manager = s3transfer.create_transfer_manager(
client=self.s3_client, config=transfer_config
)

def _copy_file(self, file: FileInfo, update_fn):
"""Something."""
copy_source = {"Bucket": file.location.bucket, "Key": file.location.key}
result = self.transfer_manager.copy(
bucket=self.dest_location.bucket,
key=f"{self.dest_location.key}/{file.location.key}",
copy_source=copy_source,
subscribers=[
s3transfer.ProgressCallbackInvoker(update_fn),
],
)
result.result()

def call(self):
"""Something."""
total_size = sum([file.size for file in self.files])
JUMPSTART_LOGGER.warning(
Review comment (Member): i don't think this should be a warning. Can we modify JUMPSTART_LOGGER so that if you do JUMPSTART_LOGGER.info("blah", stdout=True), then it gets printed? Warning gives a negative connotation.

"Copying %s files (%s) into %s/%s",
len(self.files),
human_readable_size(total_size),
self.dest_location.bucket,
self.dest_location.key,
)

progress = tqdm.tqdm(
desc="JumpStart Sync",
total=total_size,
unit="B",
unit_scale=1,
position=0,
bar_format="{desc:<10}{percentage:3.0f}%|{bar:10}{r_bar}",
)

for file in self.files:
self._copy_file(file, progress.update)

self.transfer_manager.shutdown()
progress.close()
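
Putting the pieces together, a hedged end-to-end sketch of how the handler might be invoked; the bucket names and locations are invented for illustration:

    import boto3

    from sagemaker.jumpstart.curated_hub.accessors.filegenerator import S3PathFileGenerator
    from sagemaker.jumpstart.curated_hub.accessors.multipartcopy import MultiPartCopyHandler
    from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation

    region = "us-west-2"
    src = S3ObjectLocation(bucket="example-src-bucket", key="model-artifacts/")
    dest = S3ObjectLocation(bucket="example-dest-bucket", key="curated-hub")

    # Enumerate the source files, then fan the copies out through the
    # transfer manager; call() blocks until every copy has completed.
    files = S3PathFileGenerator(region, boto3.client("s3", region_name=region)).format(src)
    MultiPartCopyHandler(region=region, files=files, dest_location=dest).call()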
@@ -15,6 +15,8 @@
from dataclasses import dataclass
from typing import Dict

from sagemaker.s3_utils import parse_s3_url


@dataclass
class S3ObjectLocation:
@@ -37,10 +39,9 @@ def get_uri(self) -> str:

def create_s3_object_reference_from_uri(s3_uri: str) -> S3ObjectLocation:
"""Utiity to help generate an S3 object reference"""
uri_with_s3_prefix_removed = s3_uri.replace("s3://", "", 1)
uri_split = uri_with_s3_prefix_removed.split("/")
bucket, key = parse_s3_url(s3_uri)

return S3ObjectLocation(
bucket=uri_split[0],
key="/".join(uri_split[1:]) if len(uri_split) > 1 else "",
bucket=bucket,
key=key,
)
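
As a quick sanity check on the swap to ``parse_s3_url``, a small illustration of the behavior the new helper relies on (values invented):

    from sagemaker.s3_utils import parse_s3_url

    # parse_s3_url splits an S3 URI into its bucket and key components,
    # replacing the hand-rolled string splitting removed above.
    bucket, key = parse_s3_url("s3://example-bucket/curated/model/artifact.tar.gz")
    assert bucket == "example-bucket"
    assert key == "curated/model/artifact.tar.gz"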