Skip to content

Commit 27240a9

Browse files
committed
optimize function and add hub bucket prefix
1 parent 97001cc commit 27240a9

File tree

4 files changed

+74
-58
lines changed

4 files changed

+74
-58
lines changed

src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,16 @@ def generate_file_infos_from_s3_location(
3535
"""
3636
parameters = {"Bucket": location.bucket, "Prefix": location.key}
3737
response = s3_client.list_objects_v2(**parameters)
38-
contents = response.get("Contents", None)
38+
contents = response.get("Contents")
3939

4040
if not contents:
4141
return []
4242

4343
files = []
4444
for s3_obj in contents:
45-
key: str = s3_obj.get("Key")
46-
size: bytes = s3_obj.get("Size", None)
47-
last_modified: str = s3_obj.get("LastModified", None)
45+
key = s3_obj.get("Key")
46+
size = s3_obj.get("Size")
47+
last_modified = s3_obj.get("LastModified")
4848
files.append(FileInfo(location.bucket, key, size, last_modified))
4949
return files
5050

@@ -71,28 +71,26 @@ def generate_file_infos_from_model_specs(
7171
if location_type == "prefix":
7272
parameters = {"Bucket": location.bucket, "Prefix": location.key}
7373
response = s3_client.list_objects_v2(**parameters)
74-
contents = response.get("Contents", None)
74+
contents = response.get("Contents")
7575
for s3_obj in contents:
76-
key: str = s3_obj.get("Key")
77-
size: bytes = s3_obj.get("Size", None)
78-
last_modified: datetime = s3_obj.get("LastModified", None)
79-
dependency_type: HubContentDependencyType = dependency
76+
key = s3_obj.get("Key")
77+
size = s3_obj.get("Size")
78+
last_modified = s3_obj.get("LastModified")
8079
files.append(
8180
FileInfo(
8281
location.bucket,
8382
key,
8483
size,
8584
last_modified,
86-
dependency_type,
85+
dependency,
8786
)
8887
)
8988
elif location_type == "object":
9089
parameters = {"Bucket": location.bucket, "Key": location.key}
9190
response = s3_client.head_object(**parameters)
92-
size: bytes = response.get("ContentLength", None)
93-
last_updated: datetime = response.get("LastModified", None)
94-
dependency_type: HubContentDependencyType = dependency
91+
size = response.get("ContentLength")
92+
last_updated = response.get("LastModified")
9593
files.append(
96-
FileInfo(location.bucket, location.key, size, last_updated, dependency_type)
94+
FileInfo(location.bucket, location.key, size, last_updated, dependency)
9795
)
9896
return files

src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class MultiPartCopyHandler(object):
6262
"""Multi Part Copy Handler class."""
6363

6464
WORKERS = 20
65+
# Config values from S3:Copy
6566
MULTIPART_CONFIG = 8 * (1024**2)
6667

6768
def __init__(

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""This module provides the JumpStart Curated Hub class."""
1414
from __future__ import absolute_import
1515
from concurrent import futures
16+
from datetime import datetime
1617
import json
1718
import traceback
1819
from typing import Optional, Dict, List, Any
@@ -50,6 +51,7 @@
5051
HubContentDocument_v2,
5152
JumpStartModelInfo,
5253
S3ObjectLocation,
54+
create_s3_object_reference_from_uri,
5355
)
5456

5557

@@ -73,20 +75,21 @@ def __init__(
7375
self.region = sagemaker_session.boto_region_name
7476
self._sagemaker_session = sagemaker_session
7577
self._default_thread_pool_size = 20
76-
self.hub_bucket_name = bucket_name or self._fetch_hub_bucket_name()
7778
self._s3_client = self._get_s3_client()
79+
self.hub_storage_location = self._generate_hub_storage_location(bucket_name)
7880

7981
def _get_s3_client(self) -> BaseClient:
80-
"""Returns an S3 client."""
82+
"""Returns an S3 client used for creating a HubContentDocument."""
8183
return boto3.client("s3", region_name=self.region)
8284

8385
def _fetch_hub_bucket_name(self) -> str:
8486
"""Retrieves hub bucket name from Hub config if exists"""
8587
try:
8688
hub_response = self._sagemaker_session.describe_hub(hub_name=self.hub_name)
87-
hub_bucket_prefix = hub_response["S3StorageConfig"].get("S3OutputPath", None)
88-
if hub_bucket_prefix:
89-
return hub_bucket_prefix.replace("s3://", "")
89+
hub_output_location = hub_response["S3StorageConfig"].get("S3OutputPath")
90+
if hub_output_location:
91+
location = create_s3_object_reference_from_uri(hub_output_location)
92+
return location.bucket
9093
default_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session)
9194
JUMPSTART_LOGGER.warning(
9295
"There is not a Hub bucket associated with %s. Using %s",
@@ -103,6 +106,12 @@ def _fetch_hub_bucket_name(self) -> str:
103106
)
104107
return hub_bucket_name
105108

109+
def _generate_hub_storage_location(
    self, bucket_name: Optional[str] = None
) -> "S3ObjectLocation":
    """Build the S3 location under which this Hub stores its content.

    Args:
        bucket_name (Optional[str]): Bucket to use. When omitted or empty, the
            bucket is resolved through ``self._fetch_hub_bucket_name()``.

    Returns:
        S3ObjectLocation: Location whose bucket is the resolved Hub bucket and
            whose key is ``<hub_name>-<timestamp>``, where the timestamp is the
            POSIX timestamp taken at the moment of the call.
    """
    hub_bucket_name = bucket_name or self._fetch_hub_bucket_name()
    # The timestamp suffix makes the key differ from one call to the next.
    # NOTE(review): presumably intended to keep Hub prefixes unique - confirm.
    curr_timestamp = datetime.now().timestamp()
    return S3ObjectLocation(bucket=hub_bucket_name, key=f"{self.hub_name}-{curr_timestamp}")
114+
106115
def create(
107116
self,
108117
description: str,
@@ -112,16 +121,16 @@ def create(
112121
) -> Dict[str, str]:
113122
"""Creates a hub with the given description"""
114123

115-
bucket_name = create_hub_bucket_if_it_does_not_exist(
116-
self.hub_bucket_name, self._sagemaker_session
124+
create_hub_bucket_if_it_does_not_exist(
125+
self.hub_storage_location.bucket, self._sagemaker_session
117126
)
118127

119128
return self._sagemaker_session.create_hub(
120129
hub_name=self.hub_name,
121130
hub_description=description,
122131
hub_display_name=display_name,
123132
hub_search_keywords=search_keywords,
124-
s3_storage_config={"S3OutputPath": f"s3://{bucket_name}"},
133+
s3_storage_config={"S3OutputPath": self.hub_storage_location.get_uri()},
125134
tags=tags,
126135
)
127136

@@ -226,7 +235,7 @@ def _get_jumpstart_models_in_hub(self) -> List[Dict[str, Any]]:
226235
return js_models_in_hub
227236

228237
def _determine_models_to_sync(
229-
self, model_list: List[JumpStartModelInfo], models_in_hub
238+
self, model_list: List[JumpStartModelInfo], models_in_hub: Dict[str, Any]
230239
) -> List[JumpStartModelInfo]:
231240
"""Determines which models from `sync` params to sync into the CuratedHub.
232241
@@ -240,14 +249,7 @@ def _determine_models_to_sync(
240249
"""
241250
models_to_sync = []
242251
for model in model_list:
243-
matched_model = next(
244-
(
245-
hub_model
246-
for hub_model in models_in_hub
247-
if hub_model and hub_model["name"] == model.model_id
248-
),
249-
None,
250-
)
252+
matched_model = models_in_hub.get(model.model_id)
251253

252254
# Model does not exist in Hub, sync
253255
if not matched_model:
@@ -300,8 +302,9 @@ def sync(self, model_list: List[Dict[str, str]]):
300302
model_version_list.append(JumpStartModelInfo(model["model_id"], model["version"]))
301303

302304
js_models_in_hub = self._get_jumpstart_models_in_hub()
305+
mapped_models_in_hub = { model["name"]: model for model in js_models_in_hub }
303306

304-
models_to_sync = self._determine_models_to_sync(model_version_list, js_models_in_hub)
307+
models_to_sync = self._determine_models_to_sync(model_version_list, mapped_models_in_hub)
305308
JUMPSTART_LOGGER.warning(
306309
"Syncing the following models into Hub %s: %s", self.hub_name, models_to_sync
307310
)
@@ -349,7 +352,8 @@ def _sync_public_model_to_hub(self, model: JumpStartModelInfo, thread_num: int):
349352
studio_specs = self._fetch_studio_specs(model_specs=model_specs)
350353

351354
dest_location = S3ObjectLocation(
352-
bucket=self.hub_bucket_name, key=f"{model.model_id}/{model.version}"
355+
bucket=self.hub_storage_location.bucket,
356+
key=f"{self.hub_storage_location.key}/{model.model_id}/{model.version}"
353357
)
354358
src_files = file_generator.generate_file_infos_from_model_specs(
355359
model_specs, studio_specs, self.region, self._s3_client

tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,25 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414
from copy import deepcopy
15+
import datetime
1516
from unittest import mock
1617
from unittest.mock import patch
1718
import pytest
1819
from mock import Mock
1920
from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub
20-
from sagemaker.jumpstart.curated_hub.types import JumpStartModelInfo
21+
from sagemaker.jumpstart.curated_hub.types import JumpStartModelInfo, S3ObjectLocation
2122
from sagemaker.jumpstart.types import JumpStartModelSpecs
2223
from tests.unit.sagemaker.jumpstart.constants import BASE_SPEC
2324
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec
2425

26+
2527
REGION = "us-east-1"
2628
ACCOUNT_ID = "123456789123"
2729
HUB_NAME = "mock-hub-name"
2830

2931
MODULE_PATH = "sagemaker.jumpstart.curated_hub.curated_hub.CuratedHub"
3032

33+
FAKE_TIME = datetime.datetime(1997, 8, 14, 00, 00, 00)
3134

3235
@pytest.fixture()
3336
def sagemaker_session():
@@ -39,7 +42,7 @@ def sagemaker_session():
3942
"Boto3/1.9.69 Python/3.6.5 Linux/4.14.77-70.82.amzn1.x86_64 Botocore/1.12.69 Resource"
4043
)
4144
sagemaker_session_mock.describe_hub.return_value = {
42-
"S3StorageConfig": {"S3OutputPath": "mock-bucket-123"}
45+
"S3StorageConfig": {"S3OutputPath": "s3://mock-bucket-123"}
4346
}
4447
sagemaker_session_mock.account_id.return_value = ACCOUNT_ID
4548
return sagemaker_session_mock
@@ -66,7 +69,9 @@ def test_instantiates(sagemaker_session):
6669
),
6770
],
6871
)
72+
@patch("sagemaker.jumpstart.curated_hub.curated_hub.CuratedHub._generate_hub_storage_location")
6973
def test_create_with_no_bucket_name(
74+
mock_generate_hub_storage_location,
7075
sagemaker_session,
7176
hub_name,
7277
hub_description,
@@ -75,18 +80,22 @@ def test_create_with_no_bucket_name(
7580
hub_search_keywords,
7681
tags,
7782
):
83+
storage_location = S3ObjectLocation("sagemaker-hubs-us-east-1-123456789123",f"{hub_name}-{FAKE_TIME.timestamp()}")
84+
mock_generate_hub_storage_location.return_value = storage_location
7885
create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"}
7986
sagemaker_session.create_hub = Mock(return_value=create_hub)
8087
sagemaker_session.describe_hub.return_value = {
81-
"S3StorageConfig": {"S3OutputPath": hub_bucket_name}
88+
"S3StorageConfig": {"S3OutputPath": f"s3://{hub_bucket_name}/{storage_location.key}"}
8289
}
8390
hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session)
8491
request = {
8592
"hub_name": hub_name,
8693
"hub_description": hub_description,
8794
"hub_display_name": hub_display_name,
8895
"hub_search_keywords": hub_search_keywords,
89-
"s3_storage_config": {"S3OutputPath": "s3://sagemaker-hubs-us-east-1-123456789123"},
96+
"s3_storage_config": {
97+
"S3OutputPath": f"s3://sagemaker-hubs-us-east-1-123456789123/{storage_location.key}"
98+
},
9099
"tags": tags,
91100
}
92101
response = hub.create(
@@ -113,7 +122,9 @@ def test_create_with_no_bucket_name(
113122
),
114123
],
115124
)
125+
@patch("sagemaker.jumpstart.curated_hub.curated_hub.CuratedHub._generate_hub_storage_location")
116126
def test_create_with_bucket_name(
127+
mock_generate_hub_storage_location,
117128
sagemaker_session,
118129
hub_name,
119130
hub_description,
@@ -122,6 +133,8 @@ def test_create_with_bucket_name(
122133
hub_search_keywords,
123134
tags,
124135
):
136+
storage_location = S3ObjectLocation(hub_bucket_name,f"{hub_name}-{FAKE_TIME.timestamp()}")
137+
mock_generate_hub_storage_location.return_value = storage_location
125138
create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"}
126139
sagemaker_session.create_hub = Mock(return_value=create_hub)
127140
hub = CuratedHub(
@@ -132,7 +145,7 @@ def test_create_with_bucket_name(
132145
"hub_description": hub_description,
133146
"hub_display_name": hub_display_name,
134147
"hub_search_keywords": hub_search_keywords,
135-
"s3_storage_config": {"S3OutputPath": "s3://mock-bucket-123"},
148+
"s3_storage_config": {"S3OutputPath": f"s3://mock-bucket-123/{storage_location.key}"},
136149
"tags": tags,
137150
}
138151
response = hub.create(
@@ -397,92 +410,92 @@ def test_determine_models_to_sync(sagemaker_session):
397410
hub_name = "mock_hub_name"
398411
hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session)
399412

400-
js_models_in_hub = [
401-
{
413+
js_model_map = {
414+
"mock-model-two-pytorch": {
402415
"name": "mock-model-two-pytorch",
403416
"version": "1.0.1",
404417
"search_keywords": [
405418
"@jumpstart-model-id:model-two-pytorch",
406419
"@jumpstart-model-version:1.0.2",
407420
],
408421
},
409-
{
422+
"mock-model-four-huggingface": {
410423
"name": "mock-model-four-huggingface",
411424
"version": "2.0.2",
412425
"search_keywords": [
413426
"@jumpstart-model-id:model-four-huggingface",
414427
"@jumpstart-model-version:2.0.2",
415428
],
416429
},
417-
]
430+
}
418431
model_one = JumpStartModelInfo("mock-model-one-huggingface", "1.2.3")
419432
model_two = JumpStartModelInfo("mock-model-two-pytorch", "1.0.2")
420433
# No model_one, older model_two
421-
res = hub._determine_models_to_sync([model_one, model_two], js_models_in_hub)
434+
res = hub._determine_models_to_sync([model_one, model_two], js_model_map)
422435
assert res == [model_one, model_two]
423436

424-
js_models_in_hub = [
425-
{
437+
js_model_map = {
438+
"mock-model-two-pytorch": {
426439
"name": "mock-model-two-pytorch",
427440
"version": "1.0.3",
428441
"search_keywords": [
429442
"@jumpstart-model-id:model-two-pytorch",
430443
"@jumpstart-model-version:1.0.3",
431444
],
432445
},
433-
{
446+
"mock-model-four-huggingface": {
434447
"name": "mock-model-four-huggingface",
435448
"version": "2.0.2",
436449
"search_keywords": [
437450
"@jumpstart-model-id:model-four-huggingface",
438451
"@jumpstart-model-version:2.0.2",
439452
],
440453
},
441-
]
454+
}
442455
# No model_one, newer model_two
443-
res = hub._determine_models_to_sync([model_one, model_two], js_models_in_hub)
456+
res = hub._determine_models_to_sync([model_one, model_two], js_model_map)
444457
assert res == [model_one]
445458

446-
js_models_in_hub = [
447-
{
459+
js_model_map = {
460+
"mock-model-one-huggingface": {
448461
"name": "mock-model-one-huggingface",
449462
"version": "1.2.3",
450463
"search_keywords": [
451464
"@jumpstart-model-id:model-one-huggingface",
452465
"@jumpstart-model-version:1.2.3",
453466
],
454467
},
455-
{
468+
"mock-model-two-pytorch": {
456469
"name": "mock-model-two-pytorch",
457470
"version": "1.0.2",
458471
"search_keywords": [
459472
"@jumpstart-model-id:model-two-pytorch",
460473
"@jumpstart-model-version:1.0.2",
461474
],
462475
},
463-
]
476+
}
464477
# Same model_one, same model_two
465-
res = hub._determine_models_to_sync([model_one, model_two], js_models_in_hub)
478+
res = hub._determine_models_to_sync([model_one, model_two], js_model_map)
466479
assert res == []
467480

468-
js_models_in_hub = [
469-
{
481+
js_model_map = {
482+
"mock-model-one-huggingface": {
470483
"name": "mock-model-one-huggingface",
471484
"version": "1.2.1",
472485
"search_keywords": [
473486
"@jumpstart-model-id:model-one-huggingface",
474487
"@jumpstart-model-version:1.2.1",
475488
],
476489
},
477-
{
490+
"mock-model-two-pytorch": {
478491
"name": "mock-model-two-pytorch",
479492
"version": "1.0.2",
480493
"search_keywords": [
481494
"@jumpstart-model-id:model-two-pytorch",
482495
"@jumpstart-model-version:1.0.2",
483496
],
484497
},
485-
]
498+
}
486499
# Old model_one, same model_two
487-
res = hub._determine_models_to_sync([model_one, model_two], js_models_in_hub)
500+
res = hub._determine_models_to_sync([model_one, model_two], js_model_map)
488501
assert res == [model_one]

0 commit comments

Comments
 (0)