-
Notifications
You must be signed in to change notification settings - Fork 1.2k
MultiPartCopy with Sync Algorithm #4475
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
344d26b
374c638
67d8ec8
2fa0503
30c2b91
ef57f14
c44acd2
297d1b6
18a5728
4554d34
c7f3f96
28c9186
97001cc
27240a9
ce73f62
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,9 +12,9 @@ | |
# language governing permissions and limitations under the License. | ||
"""This module contains important utilities related to HubContent data files.""" | ||
from __future__ import absolute_import | ||
from functools import singledispatchmethod | ||
from typing import Any, Dict, List, Optional | ||
|
||
from datetime import datetime | ||
from botocore.client import BaseClient | ||
|
||
from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo, HubContentDependencyType | ||
|
@@ -33,26 +33,19 @@ def __init__( | |
self.s3_client = s3_client | ||
self.studio_specs = studio_specs | ||
|
||
@singledispatchmethod | ||
def format(self, file_input) -> List[FileInfo]: | ||
"""Implement.""" | ||
# pylint: disable=W0107 | ||
pass | ||
|
||
@format.register | ||
def _(self, file_input: S3ObjectLocation) -> List[FileInfo]: | ||
"""Something.""" | ||
files = self.s3_format(file_input) | ||
return files | ||
"""Dispatch method that is implemented in below registered functions.""" | ||
raise NotImplementedError | ||
|
||
@format.register | ||
def _(self, file_input: JumpStartModelSpecs) -> List[FileInfo]: | ||
"""Something.""" | ||
files = self.specs_format(file_input, self.studio_specs) | ||
return files | ||
|
||
def s3_format(self, file_input: S3ObjectLocation) -> List[FileInfo]: | ||
"""Retrieves data from a bucket and formats into FileInfo""" | ||
class S3PathFileGenerator(FileGenerator): | ||
"""Utility class to help format all objects in an S3 bucket.""" | ||
|
||
def format(self, file_input: S3ObjectLocation) -> List[FileInfo]: | ||
"""Retrieves data from an S3 bucket and formats into FileInfo. | ||
|
||
Returns a list of ``FileInfo`` objects from the specified bucket location. | ||
""" | ||
parameters = {"Bucket": file_input.bucket, "Prefix": file_input.key} | ||
response = self.s3_client.list_objects_v2(**parameters) | ||
contents = response.get("Contents", None) | ||
|
@@ -66,47 +59,52 @@ def s3_format(self, file_input: S3ObjectLocation) -> List[FileInfo]: | |
key: str = s3_obj.get("Key") | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
size: bytes = s3_obj.get("Size", None) | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
last_modified: str = s3_obj.get("LastModified", None) | ||
files.append(FileInfo(key, size, last_modified)) | ||
files.append(FileInfo(file_input.bucket, key, size, last_modified)) | ||
return files | ||
|
||
def specs_format( | ||
self, file_input: JumpStartModelSpecs, studio_specs: Dict[str, Any] | ||
) -> List[FileInfo]: | ||
"""Collects data locations from JumpStart public model specs and | ||
converts into FileInfo. | ||
|
||
class ModelSpecsFileGenerator(FileGenerator): | ||
"""Utility class to help format all data paths from JumpStart public model specs.""" | ||
|
||
def format(self, file_input: JumpStartModelSpecs) -> List[FileInfo]: | ||
"""Collects data locations from JumpStart public model specs and converts into FileInfo`. | ||
|
||
Returns a list of ``FileInfo`` objects from dependencies found in the public | ||
model specs. | ||
""" | ||
public_model_data_accessor = PublicModelDataAccessor( | ||
region=self.region, model_specs=file_input, studio_specs=studio_specs | ||
region=self.region, model_specs=file_input, studio_specs=self.studio_specs | ||
) | ||
function_table = { | ||
HubContentDependencyType.INFERENCE_ARTIFACT: ( | ||
public_model_data_accessor.get_inference_artifact_s3_reference | ||
), | ||
HubContentDependencyType.TRAINING_ARTIFACT: ( | ||
public_model_data_accessor.get_training_artifact_s3_reference | ||
), | ||
HubContentDependencyType.INFERNECE_SCRIPT: ( | ||
public_model_data_accessor.get_inference_script_s3_reference | ||
), | ||
HubContentDependencyType.TRAINING_SCRIPT: ( | ||
public_model_data_accessor.get_training_script_s3_reference | ||
), | ||
HubContentDependencyType.DEFAULT_TRAINING_DATASET: ( | ||
public_model_data_accessor.get_default_training_dataset_s3_reference | ||
), | ||
HubContentDependencyType.DEMO_NOTEBOOK: ( | ||
public_model_data_accessor.get_demo_notebook_s3_reference | ||
), | ||
HubContentDependencyType.MARKDOWN: public_model_data_accessor.get_markdown_s3_reference, | ||
} | ||
files = [] | ||
for dependency in HubContentDependencyType: | ||
location = function_table[dependency]() | ||
parameters = {"Bucket": location.bucket, "Prefix": location.key} | ||
response = self.s3_client.head_object(**parameters) | ||
key: str = location.key | ||
size: bytes = response.get("ContentLength", None) | ||
last_updated: str = response.get("LastModified", None) | ||
dependency_type: HubContentDependencyType = dependency | ||
files.append(FileInfo(key, size, last_updated, dependency_type)) | ||
location: S3ObjectLocation = public_model_data_accessor.get_s3_reference(dependency) | ||
|
||
# Prefix | ||
if location.key[-1] == "/": | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
parameters = {"Bucket": location.bucket, "Prefix": location.key} | ||
response = self.s3_client.list_objects_v2(**parameters) | ||
contents = response.get("Contents", None) | ||
for s3_obj in contents: | ||
key: str = s3_obj.get("Key") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. isn't the typing |
||
size: bytes = s3_obj.get("Size", None) | ||
last_modified: datetime = s3_obj.get("LastModified", None) | ||
dependency_type: HubContentDependencyType = dependency | ||
files.append( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. too much indentation here, can we breakup into helper functions? |
||
FileInfo( | ||
location.bucket, | ||
key, | ||
size, | ||
last_modified, | ||
dependency_type, | ||
) | ||
) | ||
else: | ||
parameters = {"Bucket": location.bucket, "Key": location.key} | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
response = self.s3_client.head_object(**parameters) | ||
size: bytes = response.get("ContentLength", None) | ||
last_updated: datetime = response.get("LastModified", None) | ||
dependency_type: HubContentDependencyType = dependency | ||
files.append( | ||
FileInfo(location.bucket, location.key, size, last_updated, dependency_type) | ||
) | ||
return files |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module provides a class that perfrms functionalities similar to ``S3:Copy``.""" | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
from __future__ import absolute_import | ||
from typing import List | ||
|
||
import boto3 | ||
import botocore | ||
import tqdm | ||
|
||
from sagemaker.jumpstart.constants import JUMPSTART_LOGGER | ||
from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo | ||
from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation | ||
|
||
s3transfer = boto3.s3.transfer | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
# pylint: disable=R1705,R1710 | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def human_readable_size(value): | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Convert a size in bytes into a human readable format. | ||
|
||
For example:: | ||
|
||
>>> human_readable_size(1) | ||
'1 Byte' | ||
>>> human_readable_size(10) | ||
'10 Bytes' | ||
>>> human_readable_size(1024) | ||
'1.0 KiB' | ||
>>> human_readable_size(1024 * 1024) | ||
'1.0 MiB' | ||
|
||
:param value: The size in bytes. | ||
:return: The size in a human readable format based on base-2 units. | ||
|
||
""" | ||
base = 1024 | ||
bytes_int = float(value) | ||
|
||
if bytes_int == 1: | ||
return "1 Byte" | ||
elif bytes_int < base: | ||
return "%d Bytes" % bytes_int | ||
|
||
for i, suffix in enumerate(("KiB", "MiB", "GiB", "TiB", "PiB", "EiB")): | ||
unit = base ** (i + 2) | ||
if round((bytes_int / unit) * base) < base: | ||
return "%.1f %s" % ((base * bytes_int / unit), suffix) | ||
|
||
|
||
class MultiPartCopyHandler(object): | ||
"""Multi Part Copy Handler class.""" | ||
|
||
WORKERS = 20 | ||
MULTIPART_CONFIG = 8 * (1024**2) | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def __init__( | ||
self, | ||
region: str, | ||
files: List[FileInfo], | ||
dest_location: S3ObjectLocation, | ||
): | ||
"""Something.""" | ||
self.region = region | ||
self.files = files | ||
self.dest_location = dest_location | ||
|
||
config = botocore.config.Config(max_pool_connections=self.WORKERS) | ||
self.s3_client = boto3.client("s3", region_name=self.region, config=config) | ||
transfer_config = s3transfer.TransferConfig( | ||
multipart_threshold=self.MULTIPART_CONFIG, | ||
multipart_chunksize=self.MULTIPART_CONFIG, | ||
max_bandwidth=True, | ||
use_threads=True, | ||
max_concurrency=self.WORKERS, | ||
) | ||
self.transfer_manager = s3transfer.create_transfer_manager( | ||
client=self.s3_client, config=transfer_config | ||
) | ||
|
||
def _copy_file(self, file: FileInfo, update_fn): | ||
"""Something.""" | ||
copy_source = {"Bucket": file.location.bucket, "Key": file.location.key} | ||
result = self.transfer_manager.copy( | ||
bucket=self.dest_location.bucket, | ||
key=f"{self.dest_location.key}/{file.location.key}", | ||
copy_source=copy_source, | ||
subscribers=[ | ||
s3transfer.ProgressCallbackInvoker(update_fn), | ||
], | ||
) | ||
result.result() | ||
|
||
def call(self): | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Something.""" | ||
total_size = sum([file.size for file in self.files]) | ||
JUMPSTART_LOGGER.warning( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i don't think should be a warning. can we modify |
||
"Copying %s files (%s) into %s/%s", | ||
len(self.files), | ||
human_readable_size(total_size), | ||
self.dest_location.bucket, | ||
self.dest_location.key, | ||
) | ||
|
||
progress = tqdm.tqdm( | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
desc="JumpStart Sync", | ||
total=total_size, | ||
unit="B", | ||
unit_scale=1, | ||
position=0, | ||
bar_format="{desc:<10}{percentage:3.0f}%|{bar:10}{r_bar}", | ||
) | ||
|
||
for file in self.files: | ||
self._copy_file(file, progress.update) | ||
|
||
self.transfer_manager.shutdown() | ||
progress.close() | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need this?