Skip to content

Commit 344d26b

Browse files
committed
first pass at sync function with util classes
1 parent 352a5c1 commit 344d26b

17 files changed

+881
-17
lines changed

src/sagemaker/jumpstart/cache.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ def __init__(
101101
Default: None (no config).
102102
s3_client (Optional[boto3.client]): s3 client to use. Default: None.
103103
sagemaker_session (Optional[sagemaker.session.Session]): A SageMaker Session object,
104-
used for SageMaker interactions. Default: Session in region associated with boto3 session.
104+
used for SageMaker interactions. Default: Session in region associated with boto3
105+
session.
105106
"""
106107

107108
self._region = region
@@ -358,7 +359,9 @@ def _retrieval_function(
358359
hub_content_type=data_type
359360
)
360361

361-
model_specs = JumpStartModelSpecs(DescribeHubContentsResponse(hub_model_description), is_hub_content=True)
362+
model_specs = JumpStartModelSpecs(
363+
DescribeHubContentsResponse(hub_model_description), is_hub_content=True
364+
)
362365

363366
utils.emit_logs_based_on_model_specs(
364367
model_specs,
@@ -372,7 +375,9 @@ def _retrieval_function(
372375
hub_name, _, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info)
373376
response: Dict[str, Any] = self._sagemaker_session.describe_hub(hub_name=hub_name)
374377
hub_description = DescribeHubResponse(response)
375-
return JumpStartCachedContentValue(formatted_content=DescribeHubResponse(hub_description))
378+
return JumpStartCachedContentValue(
379+
formatted_content=DescribeHubResponse(hub_description)
380+
)
376381
raise ValueError(
377382
f"Bad value for key '{key}': must be in ",
378383
f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubType.HUB, HubContentType.MODEL]}"

src/sagemaker/jumpstart/curated_hub/accessors/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains important utilities related to HubContent data files."""
14+
from __future__ import absolute_import
15+
from functools import singledispatchmethod
16+
from typing import Any, Dict, List, Optional
17+
18+
from botocore.client import BaseClient
19+
20+
from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo, HubContentDependencyType
21+
from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation
22+
from sagemaker.jumpstart.curated_hub.accessors.public_model_data import PublicModelDataAccessor
23+
from sagemaker.jumpstart.types import JumpStartModelSpecs
24+
25+
26+
class FileGenerator:
    """Utility class to help format HubContent data files.

    ``format`` dispatches on the type of its argument: an ``S3ObjectLocation``
    is handled by ``s3_format`` and a ``JumpStartModelSpecs`` by ``specs_format``.
    """

    def __init__(
        self, region: str, s3_client: BaseClient, studio_specs: Optional[Dict[str, Any]] = None
    ):
        """Store the collaborators used by the formatting methods.

        Args:
            region (str): Region handed to ``PublicModelDataAccessor``.
            s3_client (BaseClient): Client used for the ``list_objects_v2`` and
                ``head_object`` calls.
            studio_specs (Optional[Dict[str, Any]]): Studio metadata forwarded to
                ``specs_format`` when ``format`` receives model specs. Default: None.
        """
        self.region = region
        self.s3_client = s3_client
        self.studio_specs = studio_specs

    @singledispatchmethod
    def format(self, file_input) -> List[FileInfo]:
        """Dispatch entry point; input types with no registered handler return None."""
        # pylint: disable=W0107
        pass

    @format.register
    def _(self, file_input: S3ObjectLocation) -> List[FileInfo]:
        """Format every object listed under an S3 location."""
        return self.s3_format(file_input)

    @format.register
    def _(self, file_input: JumpStartModelSpecs) -> List[FileInfo]:
        """Format the S3 dependencies referenced by JumpStart model specs."""
        return self.specs_format(file_input, self.studio_specs)

    def s3_format(self, file_input: S3ObjectLocation) -> List[FileInfo]:
        """Retrieves data from a bucket and formats into FileInfo.

        Args:
            file_input (S3ObjectLocation): Bucket and key prefix to list.

        Returns:
            List[FileInfo]: One entry per listed object, or an empty list when the
                listing has no ``Contents``.
        """
        # NOTE(review): a single list_objects_v2 call is made and no continuation
        # token is followed, so truncated listings are not handled -- confirm.
        response = self.s3_client.list_objects_v2(
            Bucket=file_input.bucket, Prefix=file_input.key
        )
        contents = response.get("Contents", None)

        if not contents:
            print("Nothing to download")
            return []

        return [
            FileInfo(
                s3_obj.get("Key"),
                s3_obj.get("Size", None),
                s3_obj.get("LastModified", None),
            )
            for s3_obj in contents
        ]

    def specs_format(
        self, file_input: JumpStartModelSpecs, studio_specs: Dict[str, Any]
    ) -> List[FileInfo]:
        """Collects data locations from JumpStart public model specs and
        converts into FileInfo.

        Args:
            file_input (JumpStartModelSpecs): Specs of the public model.
            studio_specs (Dict[str, Any]): Studio metadata handed to
                ``PublicModelDataAccessor``.

        Returns:
            List[FileInfo]: One entry per ``HubContentDependencyType`` member, tagged
                with that dependency type.
        """
        public_model_data_accessor = PublicModelDataAccessor(
            region=self.region, model_specs=file_input, studio_specs=studio_specs
        )
        # Maps each dependency type to the accessor method that resolves its location.
        function_table = {
            HubContentDependencyType.INFERENCE_ARTIFACT: (
                public_model_data_accessor.get_inference_artifact_s3_reference
            ),
            HubContentDependencyType.TRAINING_ARTIFACT: (
                public_model_data_accessor.get_training_artifact_s3_reference
            ),
            HubContentDependencyType.INFERNECE_SCRIPT: (
                public_model_data_accessor.get_inference_script_s3_reference
            ),
            HubContentDependencyType.TRAINING_SCRIPT: (
                public_model_data_accessor.get_training_script_s3_reference
            ),
            HubContentDependencyType.DEFAULT_TRAINING_DATASET: (
                public_model_data_accessor.get_default_training_dataset_s3_reference
            ),
            HubContentDependencyType.DEMO_NOTEBOOK: (
                public_model_data_accessor.get_demo_notebook_s3_reference
            ),
            HubContentDependencyType.MARKDOWN: public_model_data_accessor.get_markdown_s3_reference,
        }
        files = []
        for dependency in HubContentDependencyType:
            location = function_table[dependency]()
            # head_object addresses exactly one object and takes ``Key``;
            # ``Prefix`` is a list_objects_v2 parameter and is rejected here.
            # NOTE(review): the default training dataset reference is described as
            # a directory; head_object on a prefix may not resolve -- confirm.
            response = self.s3_client.head_object(Bucket=location.bucket, Key=location.key)
            files.append(
                FileInfo(
                    location.key,
                    response.get("ContentLength", None),
                    response.get("LastModified", None),
                    dependency,
                )
            )
        return files
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains important details related to HubContent data files."""
14+
from __future__ import absolute_import
15+
16+
from enum import Enum
17+
from dataclasses import dataclass
18+
from typing import Optional
19+
20+
21+
class HubContentDependencyType(str, Enum):
    """Enum class for HubContent dependency names."""

    INFERENCE_ARTIFACT = "INFERENCE_ARTIFACT"
    TRAINING_ARTIFACT = "TRAINING_ARTIFACT"
    INFERENCE_SCRIPT = "INFERENCE_SCRIPT"
    # Misspelled legacy name, kept as an alias of INFERENCE_SCRIPT so existing
    # references keep resolving. Aliases are skipped when iterating the enum.
    INFERNECE_SCRIPT = "INFERENCE_SCRIPT"
    TRAINING_SCRIPT = "TRAINING_SCRIPT"
    DEFAULT_TRAINING_DATASET = "DEFAULT_TRAINING_DATASET"
    DEMO_NOTEBOOK = "DEMO_NOTEBOOK"
    MARKDOWN = "MARKDOWN"
31+
32+
33+
@dataclass
class FileInfo:
    """Data class for additional S3 file info.

    The attributes are declared as dataclass fields so that the generated
    ``__eq__`` and ``__repr__`` reflect the stored values. The constructor
    accepts the same arguments, in the same order, as before.
    """

    # Object key of the file.
    name: str
    # Object size as reported by the S3 response, if any.
    size: Optional[int]
    # Last-modified marker as reported by the S3 response, if any.
    # NOTE(review): boto3 is expected to return a datetime here -- confirm and
    # adjust the annotation if so.
    last_updated: Optional[str]
    # Which HubContent dependency this file represents, when known. The field
    # keeps its existing (misspelled) name because callers may pass it by keyword.
    dependecy_type: Optional["HubContentDependencyType"] = None
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module utilites to assist S3 client calls for the Curated Hub."""
14+
from __future__ import absolute_import
15+
from dataclasses import dataclass
16+
from typing import Dict
17+
18+
19+
@dataclass
class S3ObjectLocation:
    """Bucket and key pair referring to an S3 object."""

    bucket: str
    key: str

    def format_for_s3_copy(self) -> Dict[str, str]:
        """Return the location as a ``Bucket``/``Key`` dict for S3 copy calls."""
        copy_source = dict(Bucket=self.bucket, Key=self.key)
        return copy_source

    def get_uri(self) -> str:
        """Return the location rendered as an ``s3://bucket/key`` URI."""
        return "s3://{}/{}".format(self.bucket, self.key)
32+
33+
def get_uri(self) -> str:
34+
"""Returns the s3 URI"""
35+
return f"s3://{self.bucket}/{self.key}"
36+
37+
38+
def create_s3_object_reference_from_uri(s3_uri: str) -> S3ObjectLocation:
    """Utility to help generate an S3 object reference from an S3 URI.

    Args:
        s3_uri (str): URI of the form ``s3://bucket/key``. The ``s3://`` scheme
            is optional.

    Returns:
        S3ObjectLocation: The text before the first ``/`` is the bucket and the
            remainder is the key. The key is empty when there is no ``/``.
    """
    scheme = "s3://"
    # Strip the scheme only when it is a genuine prefix, so that an "s3://"
    # occurring later in the string (e.g. inside a key) is left untouched.
    path = s3_uri[len(scheme) :] if s3_uri.startswith(scheme) else s3_uri
    bucket, _, key = path.partition("/")
    return S3ObjectLocation(bucket=bucket, key=key)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module accessors for the SageMaker JumpStart Public Hub."""
14+
from __future__ import absolute_import
15+
from typing import Dict, Any
16+
from sagemaker import model_uris, script_uris
17+
from sagemaker.jumpstart.curated_hub.utils import (
18+
get_model_framework,
19+
)
20+
from sagemaker.jumpstart.enums import JumpStartScriptScope
21+
from sagemaker.jumpstart.types import JumpStartModelSpecs
22+
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
23+
from sagemaker.jumpstart.curated_hub.accessors.objectlocation import (
24+
S3ObjectLocation,
25+
create_s3_object_reference_from_uri,
26+
)
27+
28+
29+
class PublicModelDataAccessor:
    """Resolves the S3 locations of a public JumpStart model's data files."""

    def __init__(
        self,
        region: str,
        model_specs: JumpStartModelSpecs,
        studio_specs: Dict[str, Dict[str, Any]],
    ):
        """Remember the region, its JumpStart content bucket and the model metadata."""
        self._region = region
        self._bucket = get_jumpstart_content_bucket(region)
        self.model_specs = model_specs
        # Studio metadata is kept alongside the SDK specs because the two can drift.
        self.studio_specs = studio_specs

    def get_bucket_name(self) -> str:
        """Return the JumpStart content bucket resolved at construction time."""
        return self._bucket

    def get_inference_artifact_s3_reference(self) -> S3ObjectLocation:
        """Return the S3 location of the model inference artifact."""
        artifact_uri = self._jumpstart_artifact_s3_uri(JumpStartScriptScope.INFERENCE)
        return create_s3_object_reference_from_uri(artifact_uri)

    def get_training_artifact_s3_reference(self) -> S3ObjectLocation:
        """Return the S3 location of the model training artifact."""
        artifact_uri = self._jumpstart_artifact_s3_uri(JumpStartScriptScope.TRAINING)
        return create_s3_object_reference_from_uri(artifact_uri)

    def get_inference_script_s3_reference(self) -> S3ObjectLocation:
        """Return the S3 location of the model inference script."""
        script_uri = self._jumpstart_script_s3_uri(JumpStartScriptScope.INFERENCE)
        return create_s3_object_reference_from_uri(script_uri)

    def get_training_script_s3_reference(self) -> S3ObjectLocation:
        """Return the S3 location of the model training script."""
        script_uri = self._jumpstart_script_s3_uri(JumpStartScriptScope.TRAINING)
        return create_s3_object_reference_from_uri(script_uri)

    def get_default_training_dataset_s3_reference(self) -> S3ObjectLocation:
        """Return the S3 directory that contains the model training datasets."""
        bucket_name = self.get_bucket_name()
        dataset_prefix = self._get_training_dataset_prefix()
        return S3ObjectLocation(bucket_name, dataset_prefix)

    def _get_training_dataset_prefix(self) -> str:
        """Return the ``defaultDataKey`` entry of the studio specs."""
        return self.studio_specs["defaultDataKey"]

    def get_demo_notebook_s3_reference(self) -> S3ObjectLocation:
        """Return the S3 location of the model demo jupyter notebook."""
        framework = get_model_framework(self.model_specs)
        notebook_key = f"{framework}-notebooks/{self.model_specs.model_id}-inference.ipynb"
        return S3ObjectLocation(self.get_bucket_name(), notebook_key)

    def get_markdown_s3_reference(self) -> S3ObjectLocation:
        """Return the S3 location of the model markdown."""
        framework = get_model_framework(self.model_specs)
        markdown_key = f"{framework}-metadata/{self.model_specs.model_id}.md"
        return S3ObjectLocation(self.get_bucket_name(), markdown_key)

    def _jumpstart_script_s3_uri(self, model_scope: str) -> str:
        """Look up the JumpStart script S3 URI for the given scope."""
        retrieve_kwargs = {
            "region": self._region,
            "model_id": self.model_specs.model_id,
            "model_version": self.model_specs.version,
            "script_scope": model_scope,
            "tolerate_vulnerable_model": True,
            "tolerate_deprecated_model": True,
        }
        return script_uris.retrieve(**retrieve_kwargs)

    def _jumpstart_artifact_s3_uri(self, model_scope: str) -> str:
        """Look up the JumpStart artifact S3 URI for the given scope."""
        retrieve_kwargs = {
            "region": self._region,
            "model_id": self.model_specs.model_id,
            "model_version": self.model_specs.version,
            "model_scope": model_scope,
            "tolerate_vulnerable_model": True,
            "tolerate_deprecated_model": True,
        }
        return model_uris.retrieve(**retrieve_kwargs)

0 commit comments

Comments
 (0)