Skip to content

Commit 2eff8fb

Browse files
committed
support for gated and training unsupported
1 parent f50de6b commit 2eff8fb

File tree

8 files changed

+101
-20
lines changed

8 files changed

+101
-20
lines changed

src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
S3ObjectLocation,
2323
)
2424
from sagemaker.jumpstart.curated_hub.accessors.public_model_data import PublicModelDataAccessor
25+
from sagemaker.jumpstart.curated_hub.utils import is_gated_bucket
2526
from sagemaker.jumpstart.types import JumpStartModelSpecs
2627

2728

@@ -65,6 +66,10 @@ def generate_file_infos_from_model_specs(
6566
files = []
6667
for dependency in HubContentDependencyType:
6768
location: S3ObjectLocation = public_model_data_accessor.get_s3_reference(dependency)
69+
# Training dependencies will return as None if training is unsupported
70+
if not location or is_gated_bucket(location.bucket):
71+
continue
72+
6873
location_type = "prefix" if location.key.endswith("/") else "object"
6974

7075
if location_type == "prefix":

src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains accessors for the SageMaker JumpStart Public Hub."""
1414
from __future__ import absolute_import
15-
from typing import Dict, Any
15+
from typing import Dict, Any, Optional
1616
from sagemaker import model_uris, script_uris
1717
from sagemaker.jumpstart.curated_hub.types import (
1818
HubContentDependencyType,
@@ -21,7 +21,7 @@
2121
from sagemaker.jumpstart.curated_hub.utils import create_s3_object_reference_from_uri
2222
from sagemaker.jumpstart.enums import JumpStartScriptScope
2323
from sagemaker.jumpstart.types import JumpStartModelSpecs
24-
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
24+
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket, get_jumpstart_gated_content_bucket
2525

2626

2727
class PublicModelDataAccessor:
@@ -34,7 +34,11 @@ def __init__(
3434
studio_specs: Dict[str, Dict[str, Any]],
3535
):
3636
self._region = region
37-
self._bucket = get_jumpstart_content_bucket(region)
37+
self._bucket = (
38+
get_jumpstart_gated_content_bucket(region)
39+
if model_specs.gated_bucket
40+
else get_jumpstart_content_bucket(region)
41+
)
3842
self.model_specs = model_specs
3943
self.studio_specs = studio_specs # Necessary for SDK - Studio metadata drift
4044

@@ -52,6 +56,8 @@ def inference_artifact_s3_reference(self):
5256
@property
5357
def training_artifact_s3_reference(self):
5458
"""Retrieves s3 reference for model training artifact"""
59+
if not self.model_specs.training_supported:
60+
return None
5561
return create_s3_object_reference_from_uri(
5662
self._jumpstart_artifact_s3_uri(JumpStartScriptScope.TRAINING)
5763
)
@@ -66,13 +72,17 @@ def inference_script_s3_reference(self):
6672
@property
6773
def training_script_s3_reference(self):
6874
"""Retrieves s3 reference for model training script"""
75+
if not self.model_specs.training_supported:
76+
return None
6977
return create_s3_object_reference_from_uri(
7078
self._jumpstart_script_s3_uri(JumpStartScriptScope.TRAINING)
7179
)
7280

7381
@property
7482
def default_training_dataset_s3_reference(self):
7583
"""Retrieves s3 reference for s3 directory containing model training datasets"""
84+
if not self.model_specs.training_supported:
85+
return None
7686
return S3ObjectLocation(self._get_bucket_name(), self.__get_training_dataset_prefix())
7787

7888
@property
@@ -95,22 +105,28 @@ def _get_bucket_name(self) -> str:
95105

96106
def __get_training_dataset_prefix(self) -> str:
97107
"""Retrieves training dataset location"""
98-
return self.studio_specs["defaultDataKey"]
108+
return self.studio_specs.get("defaultDataKey")
99109

100-
def _jumpstart_script_s3_uri(self, model_scope: str) -> str:
110+
def _jumpstart_script_s3_uri(self, model_scope: str) -> Optional[str]:
101111
"""Retrieves JumpStart script s3 location"""
102-
return script_uris.retrieve(
103-
region=self._region,
104-
model_id=self.model_specs.model_id,
105-
model_version=self.model_specs.version,
106-
script_scope=model_scope,
107-
)
112+
try:
113+
return script_uris.retrieve(
114+
region=self._region,
115+
model_id=self.model_specs.model_id,
116+
model_version=self.model_specs.version,
117+
script_scope=model_scope,
118+
)
119+
except ValueError:
120+
return None
108121

109-
def _jumpstart_artifact_s3_uri(self, model_scope: str) -> str:
122+
def _jumpstart_artifact_s3_uri(self, model_scope: str) -> Optional[str]:
110123
"""Retrieves JumpStart artifact s3 location"""
111-
return model_uris.retrieve(
112-
region=self._region,
113-
model_id=self.model_specs.model_id,
114-
model_version=self.model_specs.version,
115-
model_scope=model_scope,
116-
)
124+
try:
125+
return model_uris.retrieve(
126+
region=self._region,
127+
model_id=self.model_specs.model_id,
128+
model_version=self.model_specs.version,
129+
model_scope=model_scope,
130+
)
131+
except ValueError:
132+
return None

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,6 @@ def _sync_public_model_to_hub(self, model: JumpStartModelInfo, thread_num: int):
394394
self._sagemaker_session.import_hub_content(
395395
document_schema_version=HubContentDocument_v2.SCHEMA_VERSION,
396396
hub_content_name=model.model_id,
397-
hub_content_version=model.version,
398397
hub_name=self.hub_name,
399398
hub_content_document=hub_content_document,
400399
hub_content_type=HubContentType.MODEL,

src/sagemaker/jumpstart/curated_hub/utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,11 @@ def generate_default_hub_bucket_name(
133133
return f"sagemaker-hubs-{region}-{account_id}"
134134

135135

136-
def create_s3_object_reference_from_uri(s3_uri: str) -> S3ObjectLocation:
136+
def create_s3_object_reference_from_uri(s3_uri: Optional[str]) -> Optional[S3ObjectLocation]:
137137
"""Utility to help generate an S3 object reference"""
138+
if not s3_uri:
139+
return None
140+
138141
bucket, key = parse_s3_url(s3_uri)
139142

140143
return S3ObjectLocation(
@@ -164,3 +167,8 @@ def create_hub_bucket_if_it_does_not_exist(
164167
)
165168

166169
return bucket_name
170+
171+
172+
def is_gated_bucket(bucket_name: str) -> bool:
173+
"""Returns true if the bucket name is the JumpStart gated bucket."""
174+
return bucket_name in constants.JUMPSTART_GATED_BUCKET_NAME_SET

src/sagemaker/jumpstart/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,7 @@ def generate_studio_spec_file_prefix(model_id: str, model_version: str) -> str:
870870
"""Returns the Studio Spec file prefix given a model ID and version."""
871871
return f"studio_models/{model_id}/studio_specs_v{model_version}.json"
872872

873+
873874
def extract_info_from_hub_content_arn(
874875
arn: str,
875876
) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:

tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,44 @@ def test_s3_path_file_generator_with_no_objects(s3_client):
127127

128128
s3_client.list_objects_v2.assert_called_once()
129129
assert response == []
130+
131+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
132+
def test_specs_file_generator_training_unsupported(patched_get_model_specs, s3_client):
133+
specs = Mock()
134+
specs.model_id = "mock_model_123"
135+
specs.training_supported = False
136+
specs.gated_bucket = False
137+
specs.hosting_prepacked_artifact_key = "/my/inference/tarball.tgz"
138+
specs.hosting_script_key = "/my/inference/script.py"
139+
140+
response = generate_file_infos_from_model_specs(specs, {}, "us-west-2", s3_client)
141+
142+
assert response == [
143+
FileInfo(
144+
"jumpstart-cache-prod-us-west-2",
145+
"/my/inference/tarball.tgz",
146+
123456789,
147+
"08-14-1997 00:00:00",
148+
),
149+
FileInfo(
150+
"jumpstart-cache-prod-us-west-2",
151+
"/my/inference/script.py",
152+
123456789,
153+
"08-14-1997 00:00:00",
154+
),
155+
]
156+
157+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
158+
def test_specs_file_generator_gated_model(patched_get_model_specs, s3_client):
159+
specs = Mock()
160+
specs.model_id = "mock_model_123"
161+
specs.gated_bucket = True
162+
specs.training_supported = True
163+
specs.hosting_prepacked_artifact_key = "/my/inference/tarball.tgz"
164+
specs.hosting_script_key = "/my/inference/script.py"
165+
specs.training_prepacked_artifact_key = "/my/training/tarball.tgz"
166+
specs.training_script_key = "/my/training/script.py"
167+
168+
response = generate_file_infos_from_model_specs(specs, {}, "us-west-2", s3_client)
169+
170+
assert response == []

tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,3 +238,13 @@ def test_create_hub_bucket_if_it_does_not_exist():
238238

239239
mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once()
240240
assert created_hub_bucket_name == bucket_name
241+
242+
243+
def test_is_gated_bucket():
244+
assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-west-2") is True
245+
246+
assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-east-1") is True
247+
248+
assert utils.is_gated_bucket("jumpstart-cache-prod-us-west-2") is False
249+
250+
assert utils.is_gated_bucket("") is False

tests/unit/sagemaker/jumpstart/test_accessors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def test_jumpstart_models_cache_get_model_specs(mock_cache):
9898
)
9999
)
100100

101+
101102
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
102103
def test_jumpstart_proprietary_models_cache_get(mock_cache):
103104

0 commit comments

Comments
 (0)