Skip to content

Commit a632795

Browse files
bencrabtreeqidewenwhenCaptainialiujiaorrjinyoung-lim
authored
Feat/gated model support (#4510)
* fix: Move sagemaker pysdk version check after bootstrap in remote job (#4487) * feat: support JumpStart proprietary models (#4467) * feat: add proprietary manifest/specs parsing add unittests for test_cache small refactoring address comments and more unittests fix linting and fix more tests fix: pylint feat: JumpStartModel class for prop models * remove unused imports and fix docstyle * fix: remove unused args * fix: remove unused args * fix: more unused vars * fix: slow tests * fix: unittests * added more tests to cover some lines * remove estimator warn check * chore: address comments re performance * fix: address comments * complete list experience and other fixes * fix: pylint * add doc utils and fix pylint * fix: docstyle * fix: doc * fix: default payloads * fix: doc and tags and enums * fix: jumpstart doc * rename to open_weights and fix filtering * update filter name * doc update * fix: black * rename to proprietary model and fix unittests * address comments * fix: docstyle and flake8 * address more comments and fix doc * put back doc utils for future refactoring * add prop model title in doc * doc update --------- Co-authored-by: liujiaor <[email protected]> * feat: add hub and hubcontent support in retrieval function for jumpstart model cache (#4438) * feat: jsch jumpstart estimator support (#4439) * Master jumpstart curated hub (#4464) * add hub_arn support for accept_types, content_types, serializers, deserializers, and predictor (#4463) * feature: JumpStart CuratedHub class creation and function definitions (#4448) * MultiPartCopy with Sync Algorithm (#4475) * first pass at sync function with util classes * adding tests and update clases * linting * file generator class inheritance * lint * multipart copy and algorithm updates * modularize sync * reformatting folders * testing for sync * do not tolerate vulnerable * remove prints * handle multithreading progress bar * update tests * optimize function and add hub bucket prefix * docstrings and linting * rebase with master * bad rebase * support for gated and training unsupported * merge with master-curated-jumpstart * linting * update types * update * update bootstrap * fix codecov --------- Co-authored-by: qidewenwhen <[email protected]> Co-authored-by: Haotian An <[email protected]> Co-authored-by: liujiaor <[email protected]> Co-authored-by: Jinyoung Lim <[email protected]>
1 parent d820b28 commit a632795

File tree

10 files changed

+115
-99
lines changed

10 files changed

+115
-99
lines changed

src/sagemaker/jumpstart/cache.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3535
MODEL_TYPE_TO_MANIFEST_MAP,
3636
MODEL_TYPE_TO_SPECS_MAP,
37-
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3837
)
3938
from sagemaker.jumpstart.exceptions import (
4039
get_wildcard_model_version_msg,
@@ -443,7 +442,7 @@ def _retrieval_function(
443442
formatted_content=utils.get_formatted_manifest(formatted_body),
444443
md5_hash=etag,
445444
)
446-
445+
447446
if data_type in {
448447
JumpStartS3FileType.OPEN_WEIGHT_SPECS,
449448
JumpStartS3FileType.PROPRIETARY_SPECS,

src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py

+5
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
S3ObjectLocation,
2323
)
2424
from sagemaker.jumpstart.curated_hub.accessors.public_model_data import PublicModelDataAccessor
25+
from sagemaker.jumpstart.curated_hub.utils import is_gated_bucket
2526
from sagemaker.jumpstart.types import JumpStartModelSpecs
2627

2728

@@ -65,6 +66,10 @@ def generate_file_infos_from_model_specs(
6566
files = []
6667
for dependency in HubContentDependencyType:
6768
location: S3ObjectLocation = public_model_data_accessor.get_s3_reference(dependency)
69+
# Training dependencies will return as None if training is unsupported
70+
if not location or is_gated_bucket(location.bucket):
71+
continue
72+
6873
location_type = "prefix" if location.key.endswith("/") else "object"
6974

7075
if location_type == "prefix":

src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py

+46-27
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module accessors for the SageMaker JumpStart Public Hub."""
1414
from __future__ import absolute_import
15-
from typing import Dict, Any
15+
from typing import Dict, Any, Optional
1616
from sagemaker import model_uris, script_uris
1717
from sagemaker.jumpstart.curated_hub.types import (
1818
HubContentDependencyType,
@@ -21,7 +21,10 @@
2121
from sagemaker.jumpstart.curated_hub.utils import create_s3_object_reference_from_uri
2222
from sagemaker.jumpstart.enums import JumpStartScriptScope
2323
from sagemaker.jumpstart.types import JumpStartModelSpecs
24-
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
24+
from sagemaker.jumpstart.utils import (
25+
get_jumpstart_content_bucket,
26+
get_jumpstart_gated_content_bucket,
27+
)
2528

2629

2730
class PublicModelDataAccessor:
@@ -34,7 +37,11 @@ def __init__(
3437
studio_specs: Dict[str, Dict[str, Any]],
3538
):
3639
self._region = region
37-
self._bucket = get_jumpstart_content_bucket(region)
40+
self._bucket = (
41+
get_jumpstart_gated_content_bucket(region)
42+
if model_specs.gated_bucket
43+
else get_jumpstart_content_bucket(region)
44+
)
3845
self.model_specs = model_specs
3946
self.studio_specs = studio_specs # Necessary for SDK - Studio metadata drift
4047

@@ -43,47 +50,53 @@ def get_s3_reference(self, dependency_type: HubContentDependencyType):
4350
return getattr(self, dependency_type.value)
4451

4552
@property
46-
def inference_artifact_s3_reference(self):
53+
def inference_artifact_s3_reference(self) -> Optional[S3ObjectLocation]:
4754
"""Retrieves s3 reference for model inference artifact"""
4855
return create_s3_object_reference_from_uri(
4956
self._jumpstart_artifact_s3_uri(JumpStartScriptScope.INFERENCE)
5057
)
5158

5259
@property
53-
def training_artifact_s3_reference(self):
60+
def training_artifact_s3_reference(self) -> Optional[S3ObjectLocation]:
5461
"""Retrieves s3 reference for model training artifact"""
62+
if not self.model_specs.training_supported:
63+
return None
5564
return create_s3_object_reference_from_uri(
5665
self._jumpstart_artifact_s3_uri(JumpStartScriptScope.TRAINING)
5766
)
5867

5968
@property
60-
def inference_script_s3_reference(self):
69+
def inference_script_s3_reference(self) -> Optional[S3ObjectLocation]:
6170
"""Retrieves s3 reference for model inference script"""
6271
return create_s3_object_reference_from_uri(
6372
self._jumpstart_script_s3_uri(JumpStartScriptScope.INFERENCE)
6473
)
6574

6675
@property
67-
def training_script_s3_reference(self):
76+
def training_script_s3_reference(self) -> Optional[S3ObjectLocation]:
6877
"""Retrieves s3 reference for model training script"""
78+
if not self.model_specs.training_supported:
79+
return None
6980
return create_s3_object_reference_from_uri(
7081
self._jumpstart_script_s3_uri(JumpStartScriptScope.TRAINING)
7182
)
7283

7384
@property
74-
def default_training_dataset_s3_reference(self):
85+
def default_training_dataset_s3_reference(self) -> S3ObjectLocation:
7586
"""Retrieves s3 reference for s3 directory containing model training datasets"""
76-
return S3ObjectLocation(self._get_bucket_name(), self.__get_training_dataset_prefix())
87+
if not self.model_specs.training_supported:
88+
return None
89+
return S3ObjectLocation(self._get_bucket_name(), self._get_training_dataset_prefix())
7790

7891
@property
79-
def demo_notebook_s3_reference(self):
92+
def demo_notebook_s3_reference(self) -> S3ObjectLocation:
8093
"""Retrieves s3 reference for model demo jupyter notebook"""
8194
framework = self.model_specs.get_framework()
8295
key = f"{framework}-notebooks/{self.model_specs.model_id}-inference.ipynb"
8396
return S3ObjectLocation(self._get_bucket_name(), key)
8497

8598
@property
86-
def markdown_s3_reference(self):
99+
def markdown_s3_reference(self) -> S3ObjectLocation:
87100
"""Retrieves s3 reference for model markdown"""
88101
framework = self.model_specs.get_framework()
89102
key = f"{framework}-metadata/{self.model_specs.model_id}.md"
@@ -93,24 +106,30 @@ def _get_bucket_name(self) -> str:
93106
"""Retrieves s3 bucket"""
94107
return self._bucket
95108

96-
def __get_training_dataset_prefix(self) -> str:
109+
def _get_training_dataset_prefix(self) -> Optional[str]:
97110
"""Retrieves training dataset location"""
98-
return self.studio_specs["defaultDataKey"]
111+
return self.studio_specs.get("defaultDataKey")
99112

100-
def _jumpstart_script_s3_uri(self, model_scope: str) -> str:
113+
def _jumpstart_script_s3_uri(self, model_scope: str) -> Optional[str]:
101114
"""Retrieves JumpStart script s3 location"""
102-
return script_uris.retrieve(
103-
region=self._region,
104-
model_id=self.model_specs.model_id,
105-
model_version=self.model_specs.version,
106-
script_scope=model_scope,
107-
)
115+
try:
116+
return script_uris.retrieve(
117+
region=self._region,
118+
model_id=self.model_specs.model_id,
119+
model_version=self.model_specs.version,
120+
script_scope=model_scope,
121+
)
122+
except ValueError:
123+
return None
108124

109-
def _jumpstart_artifact_s3_uri(self, model_scope: str) -> str:
125+
def _jumpstart_artifact_s3_uri(self, model_scope: str) -> Optional[str]:
110126
"""Retrieves JumpStart artifact s3 location"""
111-
return model_uris.retrieve(
112-
region=self._region,
113-
model_id=self.model_specs.model_id,
114-
model_version=self.model_specs.version,
115-
model_scope=model_scope,
116-
)
127+
try:
128+
return model_uris.retrieve(
129+
region=self._region,
130+
model_id=self.model_specs.model_id,
131+
model_version=self.model_specs.version,
132+
model_scope=model_scope,
133+
)
134+
except ValueError:
135+
return None

src/sagemaker/jumpstart/curated_hub/utils.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,11 @@ def generate_default_hub_bucket_name(
133133
return f"sagemaker-hubs-{region}-{account_id}"
134134

135135

136-
def create_s3_object_reference_from_uri(s3_uri: str) -> S3ObjectLocation:
136+
def create_s3_object_reference_from_uri(s3_uri: Optional[str]) -> Optional[S3ObjectLocation]:
137137
"""Utiity to help generate an S3 object reference"""
138+
if not s3_uri:
139+
return None
140+
138141
bucket, key = parse_s3_url(s3_uri)
139142

140143
return S3ObjectLocation(
@@ -164,3 +167,8 @@ def create_hub_bucket_if_it_does_not_exist(
164167
)
165168

166169
return bucket_name
170+
171+
172+
def is_gated_bucket(bucket_name: str) -> bool:
173+
"""Returns true if the bucket name is the JumpStart gated bucket."""
174+
return bucket_name in constants.JUMPSTART_GATED_BUCKET_NAME_SET

src/sagemaker/jumpstart/utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -868,12 +868,13 @@ def generate_studio_spec_file_prefix(model_id: str, model_version: str) -> str:
868868
"""Returns the Studio Spec file prefix given a model ID and version."""
869869
return f"studio_models/{model_id}/studio_specs_v{model_version}.json"
870870

871+
871872
def extract_info_from_hub_content_arn(
872873
arn: str,
873874
) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
874875
"""Extracts hub_name, content_name, and content_version from a HubContentArn"""
875876

876-
match = re.match(constants.HUB_MODEL_ARN_REGEX, arn)
877+
match = re.match(constants.HUB_CONTENT_ARN_REGEX, arn)
877878
if match:
878879
hub_name = match.group(4)
879880
hub_region = match.group(2)

src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py

-3
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,6 @@ def main(sys_args=None):
6565
conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV")
6666

6767
RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env)
68-
RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version(
69-
client_sagemaker_pysdk_version
70-
)
7168

7269
user = getpass.getuser()
7370
if user != "root":

tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py

+43
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,46 @@ def test_s3_path_file_generator_with_no_objects(s3_client):
127127

128128
s3_client.list_objects_v2.assert_called_once()
129129
assert response == []
130+
131+
132+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
133+
def test_specs_file_generator_training_unsupported(patched_get_model_specs, s3_client):
134+
specs = Mock()
135+
specs.model_id = "mock_model_123"
136+
specs.training_supported = False
137+
specs.gated_bucket = False
138+
specs.hosting_prepacked_artifact_key = "/my/inference/tarball.tgz"
139+
specs.hosting_script_key = "/my/inference/script.py"
140+
141+
response = generate_file_infos_from_model_specs(specs, {}, "us-west-2", s3_client)
142+
143+
assert response == [
144+
FileInfo(
145+
"jumpstart-cache-prod-us-west-2",
146+
"/my/inference/tarball.tgz",
147+
123456789,
148+
"08-14-1997 00:00:00",
149+
),
150+
FileInfo(
151+
"jumpstart-cache-prod-us-west-2",
152+
"/my/inference/script.py",
153+
123456789,
154+
"08-14-1997 00:00:00",
155+
),
156+
]
157+
158+
159+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
160+
def test_specs_file_generator_gated_model(patched_get_model_specs, s3_client):
161+
specs = Mock()
162+
specs.model_id = "mock_model_123"
163+
specs.gated_bucket = True
164+
specs.training_supported = True
165+
specs.hosting_prepacked_artifact_key = "/my/inference/tarball.tgz"
166+
specs.hosting_script_key = "/my/inference/script.py"
167+
specs.training_prepacked_artifact_key = "/my/training/tarball.tgz"
168+
specs.training_script_key = "/my/training/script.py"
169+
170+
response = generate_file_infos_from_model_specs(specs, {}, "us-west-2", s3_client)
171+
172+
assert response == []

tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py

+5-23
Original file line numberDiff line numberDiff line change
@@ -177,29 +177,11 @@ def test_create_hub_bucket_if_it_does_not_exist():
177177
assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn
178178

179179

180-
def test_generate_default_hub_bucket_name():
181-
mock_sagemaker_session = Mock()
182-
mock_sagemaker_session.account_id.return_value = "123456789123"
183-
mock_sagemaker_session.boto_region_name = "us-east-1"
180+
def test_is_gated_bucket():
181+
assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-west-2") is True
184182

185-
assert (
186-
utils.generate_default_hub_bucket_name(sagemaker_session=mock_sagemaker_session)
187-
== "sagemaker-hubs-us-east-1-123456789123"
188-
)
183+
assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-east-1") is True
189184

185+
assert utils.is_gated_bucket("jumpstart-cache-prod-us-west-2") is False
190186

191-
def test_create_hub_bucket_if_it_does_not_exist():
192-
mock_sagemaker_session = Mock()
193-
mock_sagemaker_session.account_id.return_value = "123456789123"
194-
mock_sagemaker_session.client("sts").get_caller_identity.return_value = {
195-
"Account": "123456789123"
196-
}
197-
mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
198-
mock_sagemaker_session.boto_region_name = "us-east-1"
199-
bucket_name = "sagemaker-hubs-us-east-1-123456789123"
200-
created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist(
201-
sagemaker_session=mock_sagemaker_session
202-
)
203-
204-
mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once()
205-
assert created_hub_bucket_name == bucket_name
187+
assert utils.is_gated_bucket("") is False

tests/unit/sagemaker/jumpstart/test_accessors.py

+4-32
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@ def test_jumpstart_models_cache_get_model_specs(mock_cache):
8383
accessors.JumpStartModelsAccessor.get_model_specs(
8484
region=region, model_id=model_id, version=version
8585
)
86-
mock_cache.get_specs.assert_called_once_with(model_id=model_id, semantic_version_str=version)
86+
mock_cache.get_specs.assert_called_once_with(
87+
model_id=model_id, version_str=version, model_type=JumpStartModelType.OPEN_WEIGHTS
88+
)
8789
mock_cache.get_hub_model.assert_not_called()
8890

8991
accessors.JumpStartModelsAccessor.get_model_specs(
@@ -98,6 +100,7 @@ def test_jumpstart_models_cache_get_model_specs(mock_cache):
98100
)
99101
)
100102

103+
101104
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
102105
def test_jumpstart_proprietary_models_cache_get(mock_cache):
103106

@@ -138,37 +141,6 @@ def test_jumpstart_proprietary_models_cache_get(mock_cache):
138141
)
139142

140143

141-
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
142-
def test_jumpstart_models_cache_get_model_specs(mock_cache):
143-
mock_cache.get_specs = Mock()
144-
mock_cache.get_hub_model = Mock()
145-
model_id, version = "pytorch-ic-mobilenet-v2", "*"
146-
region = "us-west-2"
147-
148-
accessors.JumpStartModelsAccessor.get_model_specs(
149-
region=region, model_id=model_id, version=version
150-
)
151-
mock_cache.get_specs.assert_called_once_with(
152-
model_id=model_id, version_str=version, model_type=JumpStartModelType.OPEN_WEIGHTS
153-
)
154-
mock_cache.get_hub_model.assert_not_called()
155-
156-
accessors.JumpStartModelsAccessor.get_model_specs(
157-
region=region,
158-
model_id=model_id,
159-
version=version,
160-
hub_arn=f"arn:aws:sagemaker:{region}:123456789123:hub/my-mock-hub",
161-
)
162-
mock_cache.get_hub_model.assert_called_once_with(
163-
hub_model_arn=(
164-
f"arn:aws:sagemaker:{region}:123456789123:hub-content/my-mock-hub/Model/{model_id}/{version}"
165-
)
166-
)
167-
168-
# necessary because accessors is a static module
169-
reload(accessors)
170-
171-
172144
@patch("sagemaker.jumpstart.cache.JumpStartModelsCache")
173145
def test_jumpstart_models_cache_set_reset(mock_model_cache: Mock):
174146

tests/unit/sagemaker/jumpstart/utils.py

-10
Original file line numberDiff line numberDiff line change
@@ -254,16 +254,6 @@ def patched_retrieval_function(
254254
)
255255
)
256256

257-
if datatype == HubContentType.MODEL:
258-
_, _, _, model_name, model_version = id_info.split("/")
259-
return JumpStartCachedContentValue(
260-
formatted_content=get_spec_from_base_spec(model_id=model_name, version=model_version)
261-
)
262-
263-
# TODO: Implement
264-
if datatype == HubType.HUB:
265-
return None
266-
267257
raise ValueError(f"Bad value for datatype: {datatype}")
268258

269259

0 commit comments

Comments
 (0)