Skip to content

MultiPartCopy with Sync Algorithm #4475

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Mar 12, 2024
11 changes: 8 additions & 3 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ def __init__(
Default: None (no config).
s3_client (Optional[boto3.client]): s3 client to use. Default: None.
sagemaker_session (Optional[sagemaker.session.Session]): A SageMaker Session object,
used for SageMaker interactions. Default: Session in region associated with boto3 session.
used for SageMaker interactions. Default: Session in region associated with boto3
session.
"""

self._region = region
Expand Down Expand Up @@ -358,7 +359,9 @@ def _retrieval_function(
hub_content_type=data_type
)

model_specs = JumpStartModelSpecs(DescribeHubContentsResponse(hub_model_description), is_hub_content=True)
model_specs = JumpStartModelSpecs(
DescribeHubContentsResponse(hub_model_description), is_hub_content=True
)

utils.emit_logs_based_on_model_specs(
model_specs,
Expand All @@ -372,7 +375,9 @@ def _retrieval_function(
hub_name, _, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info)
response: Dict[str, Any] = self._sagemaker_session.describe_hub(hub_name=hub_name)
hub_description = DescribeHubResponse(response)
return JumpStartCachedContentValue(formatted_content=DescribeHubResponse(hub_description))
return JumpStartCachedContentValue(
formatted_content=DescribeHubResponse(hub_description)
)
raise ValueError(
f"Bad value for key '{key}': must be in ",
f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubType.HUB, HubContentType.MODEL]}"
Expand Down
Empty file.
93 changes: 93 additions & 0 deletions src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains important utilities related to HubContent data files."""
from __future__ import absolute_import
from typing import Any, Dict, List

from botocore.client import BaseClient

from sagemaker.jumpstart.curated_hub.types import (
FileInfo,
HubContentDependencyType,
S3ObjectLocation,
)
from sagemaker.jumpstart.curated_hub.accessors.public_model_data import PublicModelDataAccessor
from sagemaker.jumpstart.types import JumpStartModelSpecs


def generate_file_infos_from_s3_location(
    location: S3ObjectLocation, s3_client: BaseClient
) -> List[FileInfo]:
    """Lists objects from an S3 bucket and formats into FileInfo.

    Args:
        location (S3ObjectLocation): bucket and key prefix to list under.
        s3_client (BaseClient): boto3 S3 client used to perform the listing.

    Returns:
        List[FileInfo]: one entry per object found under the prefix; an empty
        list when the prefix matches no objects.
    """
    files: List[FileInfo] = []
    parameters: Dict[str, Any] = {"Bucket": location.bucket, "Prefix": location.key}

    # Paginate explicitly: a single ListObjectsV2 call returns at most 1000 keys,
    # which previously truncated the file list for large model artifacts.
    while True:
        response = s3_client.list_objects_v2(**parameters)
        # "Contents" is absent from the response when nothing matched the prefix.
        for s3_obj in response.get("Contents") or []:
            files.append(
                FileInfo(
                    location.bucket,
                    s3_obj.get("Key"),
                    s3_obj.get("Size"),
                    s3_obj.get("LastModified"),
                )
            )
        if not response.get("IsTruncated"):
            break
        parameters["ContinuationToken"] = response.get("NextContinuationToken")

    return files


def generate_file_infos_from_model_specs(
    model_specs: JumpStartModelSpecs,
    studio_specs: Dict[str, Any],
    region: str,
    s3_client: BaseClient,
) -> List[FileInfo]:
    """Collects data locations from JumpStart public model specs and converts into `FileInfo`.

    Args:
        model_specs (JumpStartModelSpecs): public specs of the model whose
            dependencies should be resolved.
        studio_specs (Dict[str, Any]): Studio metadata for the same model.
        region (str): region used to resolve the public content bucket.
        s3_client (BaseClient): boto3 S3 client used to stat the dependencies.

    Returns:
        List[FileInfo]: one entry per dependency file found in the public
        model specs.
    """
    public_model_data_accessor = PublicModelDataAccessor(
        region=region, model_specs=model_specs, studio_specs=studio_specs
    )
    files: List[FileInfo] = []
    for dependency in HubContentDependencyType:
        location: S3ObjectLocation = public_model_data_accessor.get_s3_reference(dependency)
        # NOTE(review): a trailing "/" is a heuristic for prefix-vs-object; a
        # HeadObject probe would be more robust — confirm upstream key conventions.
        if location.key.endswith("/"):
            files.extend(_list_prefix_files(location, dependency, s3_client))
        else:
            files.append(_head_object_file(location, dependency, s3_client))
    return files


def _list_prefix_files(
    location: S3ObjectLocation,
    dependency: HubContentDependencyType,
    s3_client: BaseClient,
) -> List[FileInfo]:
    """Returns a ``FileInfo`` per object found under an S3 prefix dependency."""
    response = s3_client.list_objects_v2(Bucket=location.bucket, Prefix=location.key)
    # Guard: "Contents" is absent when the prefix matches no objects; the
    # previous inline loop raised TypeError in that case.
    return [
        FileInfo(
            location.bucket,
            s3_obj.get("Key"),
            s3_obj.get("Size"),
            s3_obj.get("LastModified"),
            dependency,
        )
        for s3_obj in response.get("Contents") or []
    ]


def _head_object_file(
    location: S3ObjectLocation,
    dependency: HubContentDependencyType,
    s3_client: BaseClient,
) -> FileInfo:
    """Returns a ``FileInfo`` for a single-object S3 dependency via HeadObject."""
    response = s3_client.head_object(Bucket=location.bucket, Key=location.key)
    return FileInfo(
        location.bucket,
        location.key,
        response.get("ContentLength"),
        response.get("LastModified"),
        dependency,
    )
143 changes: 143 additions & 0 deletions src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module provides a class that performs functionalities similar to ``S3:Copy``."""
from __future__ import absolute_import
from typing import Optional

import boto3
import botocore
import tqdm

from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
from sagemaker.jumpstart.curated_hub.types import FileInfo
from sagemaker.jumpstart.curated_hub.sync.request import HubSyncRequest

s3transfer = boto3.s3.transfer


def human_readable_size(value: int) -> str:
    """Convert a size in bytes into a human readable format.

    For example::

        >>> human_readable_size(1)
        '1 Byte'
        >>> human_readable_size(10)
        '10 Bytes'
        >>> human_readable_size(1024)
        '1.0 KiB'
        >>> human_readable_size(1024 * 1024)
        '1.0 MiB'

    :param value: The size in bytes.
    :return: The size in a human readable format based on base-2 units.

    """
    base = 1024
    bytes_int = float(value)

    if bytes_int == 1:
        return "1 Byte"
    if bytes_int < base:
        return "%d Bytes" % bytes_int

    for i, suffix in enumerate(("KiB", "MiB", "GiB", "TiB", "PiB", "EiB")):
        unit = base ** (i + 2)
        # Round to one decimal of the NEXT unit; fall through to a larger
        # suffix once the rounded value would reach 1024 of the current one.
        if round((bytes_int / unit) * base) < base:
            return "%.1f %s" % ((base * bytes_int / unit), suffix)

    # Fix: the original fell off the end (implicitly returning None) for sizes
    # beyond the EiB bucket; report anything that large in EiB instead.
    return "%.1f EiB" % (bytes_int / base**6)

class MultiPartCopyHandler(object):
    """Multi Part Copy Handler class.

    Orchestrates parallel multi-part ``S3:Copy`` requests for the set of files
    resolved by a ``HubSyncRequest``, with a tqdm progress bar.
    """

    # Connection-pool size and s3transfer worker-thread count.
    WORKERS = 20
    # Config values from in S3:Copy — 8 MiB multipart threshold and chunk size.
    MULTIPART_CONFIG = 8 * (1024**2)

    def __init__(
        self,
        region: str,
        sync_request: HubSyncRequest,
        label: Optional[str] = None,
        thread_num: Optional[int] = 0,
    ):
        """Multi-part S3:Copy Handler initializer.

        Args:
            region (str): Region for the S3 Client.
            sync_request (HubSyncRequest): sync request object containing the
                files to copy and the destination location.
            label (Optional[str]): label shown on the progress bar. Default: None.
            thread_num (Optional[int]): vertical position of this handler's
                progress bar when several handlers render at once. Default: 0.
        """
        self.label = label
        self.region = region
        self.files = sync_request.files
        self.dest_location = sync_request.destination
        self.thread_num = thread_num

        config = botocore.config.Config(max_pool_connections=self.WORKERS)
        self.s3_client = boto3.client("s3", region_name=self.region, config=config)
        transfer_config = s3transfer.TransferConfig(
            multipart_threshold=self.MULTIPART_CONFIG,
            multipart_chunksize=self.MULTIPART_CONFIG,
            # Fix: max_bandwidth is an int rate limit in bytes/sec (or None for
            # unlimited). The previous value ``True`` (== 1) throttled every
            # transfer to roughly one byte per second.
            max_bandwidth=None,
            use_threads=True,
            max_concurrency=self.WORKERS,
        )
        self.transfer_manager = s3transfer.create_transfer_manager(
            client=self.s3_client, config=transfer_config
        )

    def _copy_file(self, file: FileInfo, progress_cb):
        """Performs the actual MultiPart S3:Copy of the object.

        Args:
            file (FileInfo): source object to copy.
            progress_cb (Callable[[int], None]): called with the number of
                bytes transferred as the copy progresses.
        """
        copy_source = {"Bucket": file.location.bucket, "Key": file.location.key}
        result = self.transfer_manager.copy(
            bucket=self.dest_location.bucket,
            key=f"{self.dest_location.key}/{file.location.key}",
            copy_source=copy_source,
            subscribers=[
                s3transfer.ProgressCallbackInvoker(progress_cb),
            ],
        )
        # Attempt to access result to throw error if exists. Silently calls if successful.
        result.result()

    def execute(self):
        """Executes the MultiPart S3:Copy on the class.

        Sets up progress bar and kicks off each copy request.
        """
        total_size = sum([file.size for file in self.files])
        JUMPSTART_LOGGER.warning(
            "Copying %s files (%s) into %s/%s",
            len(self.files),
            human_readable_size(total_size),
            self.dest_location.bucket,
            self.dest_location.key,
        )

        progress = tqdm.tqdm(
            desc=self.label,
            total=total_size,
            unit="B",
            unit_scale=1,
            position=self.thread_num,
            bar_format="{desc:<10}{percentage:3.0f}%|{bar:10}{r_bar}",
        )

        for file in self.files:
            self._copy_file(file, progress.update)

        # Call `shutdown` to wait for copy results
        self.transfer_manager.shutdown()
        progress.close()
116 changes: 116 additions & 0 deletions src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains accessors for the SageMaker JumpStart Public Hub."""
from __future__ import absolute_import
from typing import Dict, Any
from sagemaker import model_uris, script_uris
from sagemaker.jumpstart.curated_hub.types import (
HubContentDependencyType,
S3ObjectLocation,
)
from sagemaker.jumpstart.curated_hub.utils import create_s3_object_reference_from_uri
from sagemaker.jumpstart.enums import JumpStartScriptScope
from sagemaker.jumpstart.types import JumpStartModelSpecs
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket


class PublicModelDataAccessor:
    """Accessor class for JumpStart model data s3 locations."""

    def __init__(
        self,
        region: str,
        model_specs: JumpStartModelSpecs,
        studio_specs: Dict[str, Dict[str, Any]],
    ):
        """Initializes the accessor for a single public model.

        Args:
            region (str): region used to resolve the JumpStart content bucket.
            model_specs (JumpStartModelSpecs): public specs for the model.
            studio_specs (Dict[str, Dict[str, Any]]): Studio metadata for the model.
        """
        self._region = region
        self._bucket = get_jumpstart_content_bucket(region)
        self.model_specs = model_specs
        self.studio_specs = studio_specs  # Necessary for SDK - Studio metadata drift

    def get_s3_reference(self, dependency_type: HubContentDependencyType):
        """Retrieves S3 reference given a HubContentDependencyType."""
        # Property names below intentionally match HubContentDependencyType values.
        return getattr(self, dependency_type.value)

    @property
    def inference_artifact_s3_reference(self):
        """Retrieves s3 reference for model inference artifact"""
        return create_s3_object_reference_from_uri(
            self._jumpstart_artifact_s3_uri(JumpStartScriptScope.INFERENCE)
        )

    @property
    def training_artifact_s3_reference(self):
        """Retrieves s3 reference for model training artifact"""
        return create_s3_object_reference_from_uri(
            self._jumpstart_artifact_s3_uri(JumpStartScriptScope.TRAINING)
        )

    @property
    def inference_script_s3_reference(self):
        """Retrieves s3 reference for model inference script"""
        return create_s3_object_reference_from_uri(
            self._jumpstart_script_s3_uri(JumpStartScriptScope.INFERENCE)
        )

    @property
    def training_script_s3_reference(self):
        """Retrieves s3 reference for model training script"""
        return create_s3_object_reference_from_uri(
            self._jumpstart_script_s3_uri(JumpStartScriptScope.TRAINING)
        )

    @property
    def default_training_dataset_s3_reference(self):
        """Retrieves s3 reference for s3 directory containing model training datasets"""
        return S3ObjectLocation(self._get_bucket_name(), self._get_training_dataset_prefix())

    @property
    def demo_notebook_s3_reference(self):
        """Retrieves s3 reference for model demo jupyter notebook"""
        # NOTE(review): key layout is hard-coded; would break if the public
        # bucket's file organization changes — confirm with content owners.
        framework = self.model_specs.get_framework()
        key = f"{framework}-notebooks/{self.model_specs.model_id}-inference.ipynb"
        return S3ObjectLocation(self._get_bucket_name(), key)

    @property
    def markdown_s3_reference(self):
        """Retrieves s3 reference for model markdown"""
        framework = self.model_specs.get_framework()
        key = f"{framework}-metadata/{self.model_specs.model_id}.md"
        return S3ObjectLocation(self._get_bucket_name(), key)

    def _get_bucket_name(self) -> str:
        """Retrieves s3 bucket"""
        return self._bucket

    def _get_training_dataset_prefix(self) -> str:
        """Retrieves training dataset location.

        Renamed from double to single leading underscore for consistency with the
        other private members; the mangled name had no external callers.
        """
        return self.studio_specs["defaultDataKey"]

    def _jumpstart_script_s3_uri(self, model_scope: str) -> str:
        """Retrieves JumpStart script s3 location"""
        return script_uris.retrieve(
            region=self._region,
            model_id=self.model_specs.model_id,
            model_version=self.model_specs.version,
            script_scope=model_scope,
        )

    def _jumpstart_artifact_s3_uri(self, model_scope: str) -> str:
        """Retrieves JumpStart artifact s3 location"""
        return model_uris.retrieve(
            region=self._region,
            model_id=self.model_specs.model_id,
            model_version=self.model_specs.version,
            model_scope=model_scope,
        )
Loading