Skip to content

MultiPartCopy with Sync Algorithm #4475

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Mar 12, 2024
11 changes: 8 additions & 3 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ def __init__(
Default: None (no config).
s3_client (Optional[boto3.client]): s3 client to use. Default: None.
sagemaker_session (Optional[sagemaker.session.Session]): A SageMaker Session object,
used for SageMaker interactions. Default: Session in region associated with boto3 session.
used for SageMaker interactions. Default: Session in region associated with boto3
session.
"""

self._region = region
Expand Down Expand Up @@ -358,7 +359,9 @@ def _retrieval_function(
hub_content_type=data_type
)

model_specs = JumpStartModelSpecs(DescribeHubContentsResponse(hub_model_description), is_hub_content=True)
model_specs = JumpStartModelSpecs(
DescribeHubContentsResponse(hub_model_description), is_hub_content=True
)

utils.emit_logs_based_on_model_specs(
model_specs,
Expand All @@ -372,7 +375,9 @@ def _retrieval_function(
hub_name, _, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info)
response: Dict[str, Any] = self._sagemaker_session.describe_hub(hub_name=hub_name)
hub_description = DescribeHubResponse(response)
return JumpStartCachedContentValue(formatted_content=DescribeHubResponse(hub_description))
return JumpStartCachedContentValue(
formatted_content=DescribeHubResponse(hub_description)
)
raise ValueError(
f"Bad value for key '{key}': must be in ",
f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubType.HUB, HubContentType.MODEL]}"
Expand Down
Empty file.
112 changes: 112 additions & 0 deletions src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains important utilities related to HubContent data files."""
from __future__ import absolute_import
from functools import singledispatchmethod
from typing import Any, Dict, List, Optional

from botocore.client import BaseClient

from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo, HubContentDependencyType
from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation
from sagemaker.jumpstart.curated_hub.accessors.public_model_data import PublicModelDataAccessor
from sagemaker.jumpstart.types import JumpStartModelSpecs


class FileGenerator:
"""Utility class to help format HubContent data files."""

def __init__(
self, region: str, s3_client: BaseClient, studio_specs: Optional[Dict[str, Any]] = None
):
self.region = region
self.s3_client = s3_client
self.studio_specs = studio_specs

@singledispatchmethod
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't seen this annotation before, I assume it's lifted from aws s3 sync. Can we use a simpler implementation without these fancy annotations? Unless they're absolutely necessary, I feel like they make maintainability more difficult.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ya, this is just a really fancy way of implementing an if/else block. The true right way of doing this is through polymorphism where the incoming objects all implement a common interface (in this case, they all would define a method called .format() that you would be able to call). If that requires significant refactor, then if/else (or a case/switch in higher versions of python) is probably the better way to go

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the input I'll take a look at this. I wanted to take in a single input field that could be one of two types. Realistically, I can have two optional input fields and assert that at least one must be defined. Then I can branch my logic

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with Ben on this, I believe this could be a good practice long-term to avoid bloating our functions with optional fields. I'd argue that singledispatchmethod achieves polymorphism in a functional rather than OOP style. Instead of having a Factory that creates multiple classes that implement format(), Ben's implementation here cuts down on class boilerplate by method overloading format(input) directly with single-line decorators. IMO it's confusing because it's a new paradigm we're not used to yet, not because it's a bad way to implement it

def format(self, file_input) -> List[FileInfo]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need this?

"""Implement."""
# pylint: disable=W0107
pass

@format.register
def _(self, file_input: S3ObjectLocation) -> List[FileInfo]:
"""Something."""
files = self.s3_format(file_input)
return files

@format.register
def _(self, file_input: JumpStartModelSpecs) -> List[FileInfo]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a way of implementing the @singledispatch function above. I wanted the .format function to take in one of two params and perform different actions. Essentially if/else block but this was nicer with input types

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"""Something."""
files = self.specs_format(file_input, self.studio_specs)
return files

def s3_format(self, file_input: S3ObjectLocation) -> List[FileInfo]:
"""Retrieves data from a bucket and formats into FileInfo"""
parameters = {"Bucket": file_input.bucket, "Prefix": file_input.key}
response = self.s3_client.list_objects_v2(**parameters)
contents = response.get("Contents", None)

if not contents:
print("Nothing to download")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we are fine with regular print statements?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Placeholder for now, I'll double check if we want to use logger (prob the case)

return []

files = []
for s3_obj in contents:
key: str = s3_obj.get("Key")
size: bytes = s3_obj.get("Size", None)
last_modified: str = s3_obj.get("LastModified", None)
files.append(FileInfo(key, size, last_modified))
return files

def specs_format(
self, file_input: JumpStartModelSpecs, studio_specs: Dict[str, Any]
) -> List[FileInfo]:
"""Collects data locations from JumpStart public model specs and
converts into FileInfo.
"""
public_model_data_accessor = PublicModelDataAccessor(
region=self.region, model_specs=file_input, studio_specs=studio_specs
)
function_table = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we rename, hub_content_dependency_to_accessor_dict or something to that effect? And can we put in another common module?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I use the instantiated PublicModelDataAccessor, but I can likely store that as a constant in the class and use that to map to the function

HubContentDependencyType.INFERENCE_ARTIFACT: (
public_model_data_accessor.get_inference_artifact_s3_reference
),
HubContentDependencyType.TRAINING_ARTIFACT: (
public_model_data_accessor.get_training_artifact_s3_reference
),
HubContentDependencyType.INFERNECE_SCRIPT: (
public_model_data_accessor.get_inference_script_s3_reference
),
HubContentDependencyType.TRAINING_SCRIPT: (
public_model_data_accessor.get_training_script_s3_reference
),
HubContentDependencyType.DEFAULT_TRAINING_DATASET: (
public_model_data_accessor.get_default_training_dataset_s3_reference
),
HubContentDependencyType.DEMO_NOTEBOOK: (
public_model_data_accessor.get_demo_notebook_s3_reference
),
HubContentDependencyType.MARKDOWN: public_model_data_accessor.get_markdown_s3_reference,
}
files = []
for dependency in HubContentDependencyType:
location = function_table[dependency]()
parameters = {"Bucket": location.bucket, "Prefix": location.key}
response = self.s3_client.head_object(**parameters)
key: str = location.key
size: bytes = response.get("ContentLength", None)
last_updated: str = response.get("LastModified", None)
dependency_type: HubContentDependencyType = dependency
files.append(FileInfo(key, size, last_updated, dependency_type))
return files
47 changes: 47 additions & 0 deletions src/sagemaker/jumpstart/curated_hub/accessors/fileinfo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains important details related to HubContent data files."""
from __future__ import absolute_import

from enum import Enum
from dataclasses import dataclass
from typing import Optional


class HubContentDependencyType(str, Enum):
"""Enum class for HubContent dependency names"""

INFERENCE_ARTIFACT = "INFERENCE_ARTIFACT"
TRAINING_ARTIFACT = "TRAINING_ARTIFACT"
INFERNECE_SCRIPT = "INFERENCE_SCRIPT"
TRAINING_SCRIPT = "TRAINING_SCRIPT"
DEFAULT_TRAINING_DATASET = "DEFAULT_TRAINING_DATASET"
DEMO_NOTEBOOK = "DEMO_NOTEBOOK"
MARKDOWN = "MARKDOWN"


@dataclass
class FileInfo:
"""Data class for additional S3 file info."""

def __init__(
self,
name: str,
size: Optional[bytes],
last_updated: Optional[str],
dependecy_type: Optional[HubContentDependencyType] = None,
):
self.name = name
self.size = size
self.last_updated = last_updated
self.dependecy_type = dependecy_type
46 changes: 46 additions & 0 deletions src/sagemaker/jumpstart/curated_hub/accessors/objectlocation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module utilites to assist S3 client calls for the Curated Hub."""
from __future__ import absolute_import
from dataclasses import dataclass
from typing import Dict


@dataclass
class S3ObjectLocation:
"""Helper class for S3 object references"""

bucket: str
key: str

def format_for_s3_copy(self) -> Dict[str, str]:
"""Returns a dict formatted for S3 copy calls"""
return {
"Bucket": self.bucket,
"Key": self.key,
}

def get_uri(self) -> str:
"""Returns the s3 URI"""
return f"s3://{self.bucket}/{self.key}"


def create_s3_object_reference_from_uri(s3_uri: str) -> S3ObjectLocation:
"""Utiity to help generate an S3 object reference"""
uri_with_s3_prefix_removed = s3_uri.replace("s3://", "", 1)
uri_split = uri_with_s3_prefix_removed.split("/")

return S3ObjectLocation(
bucket=uri_split[0],
key="/".join(uri_split[1:]) if len(uri_split) > 1 else "",
)
111 changes: 111 additions & 0 deletions src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module accessors for the SageMaker JumpStart Public Hub."""
from __future__ import absolute_import
from typing import Dict, Any
from sagemaker import model_uris, script_uris
from sagemaker.jumpstart.curated_hub.utils import (
get_model_framework,
)
from sagemaker.jumpstart.enums import JumpStartScriptScope
from sagemaker.jumpstart.types import JumpStartModelSpecs
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
from sagemaker.jumpstart.curated_hub.accessors.objectlocation import (
S3ObjectLocation,
create_s3_object_reference_from_uri,
)


class PublicModelDataAccessor:
"""Accessor class for JumpStart model data s3 locations."""

def __init__(
self,
region: str,
model_specs: JumpStartModelSpecs,
studio_specs: Dict[str, Dict[str, Any]],
):
self._region = region
self._bucket = get_jumpstart_content_bucket(region)
self.model_specs = model_specs
self.studio_specs = studio_specs # Necessary for SDK - Studio metadata drift

def get_bucket_name(self) -> str:
"""Retrieves s3 bucket"""
return self._bucket

def get_inference_artifact_s3_reference(self) -> S3ObjectLocation:
"""Retrieves s3 reference for model inference artifact"""
return create_s3_object_reference_from_uri(
self._jumpstart_artifact_s3_uri(JumpStartScriptScope.INFERENCE)
)

def get_training_artifact_s3_reference(self) -> S3ObjectLocation:
"""Retrieves s3 reference for model training artifact"""
return create_s3_object_reference_from_uri(
self._jumpstart_artifact_s3_uri(JumpStartScriptScope.TRAINING)
)

def get_inference_script_s3_reference(self) -> S3ObjectLocation:
"""Retrieves s3 reference for model inference script"""
return create_s3_object_reference_from_uri(
self._jumpstart_script_s3_uri(JumpStartScriptScope.INFERENCE)
)

def get_training_script_s3_reference(self) -> S3ObjectLocation:
"""Retrieves s3 reference for model training script"""
return create_s3_object_reference_from_uri(
self._jumpstart_script_s3_uri(JumpStartScriptScope.TRAINING)
)

def get_default_training_dataset_s3_reference(self) -> S3ObjectLocation:
"""Retrieves s3 reference for s3 directory containing model training datasets"""
return S3ObjectLocation(self.get_bucket_name(), self._get_training_dataset_prefix())

def _get_training_dataset_prefix(self) -> str:
"""Retrieves training dataset location"""
return self.studio_specs["defaultDataKey"]

def get_demo_notebook_s3_reference(self) -> S3ObjectLocation:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is for inference notebooks, no? Naming feels a little unclear

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No it's for the demo notebook. We can also store inference notebooks, but that is not found in metadata

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, why is the key key = f"{framework}-notebooks/{self.model_specs.model_id}-inference.ipynb" then?

"""Retrieves s3 reference for model demo jupyter notebook"""
framework = get_model_framework(self.model_specs)
key = f"{framework}-notebooks/{self.model_specs.model_id}-inference.ipynb"
return S3ObjectLocation(self.get_bucket_name(), key)

def get_markdown_s3_reference(self) -> S3ObjectLocation:
"""Retrieves s3 reference for model markdown"""
framework = get_model_framework(self.model_specs)
key = f"{framework}-metadata/{self.model_specs.model_id}.md"
return S3ObjectLocation(self.get_bucket_name(), key)

def _jumpstart_script_s3_uri(self, model_scope: str) -> str:
"""Retrieves JumpStart script s3 location"""
return script_uris.retrieve(
region=self._region,
model_id=self.model_specs.model_id,
model_version=self.model_specs.version,
script_scope=model_scope,
tolerate_vulnerable_model=True,
tolerate_deprecated_model=True,
)

def _jumpstart_artifact_s3_uri(self, model_scope: str) -> str:
"""Retrieves JumpStart artifact s3 location"""
return model_uris.retrieve(
region=self._region,
model_id=self.model_specs.model_id,
model_version=self.model_specs.version,
model_scope=model_scope,
tolerate_vulnerable_model=True,
tolerate_deprecated_model=True,
)
Loading