Skip to content

Commit d32067d

Browse files
Captainialiujiaorr
authored and committed
feat: support JumpStart proprietary models (aws#4467)
* feat: add proprietary manifest/specs parsing add unittests for test_cache small refactoring address comments and more unittests fix linting and fix more tests fix: pylint feat: JumpStartModel class for prop models * remove unused imports and fix docstyle * fix: remove unused args * fix: remove unused args * fix: more unused vars * fix: slow tests * fix: unittests * added more tests to cover some lines * remove estimator warn check * chore: address comments re performance * fix: address comments * complete list experience and other fixes * fix: pylint * add doc utils and fix pylint * fix: docstyle * fix: doc * fix: default payloads * fix: doc and tags and enums * fix: jumpstart doc * rename to open_weights and fix filtering * update filter name * doc update * fix: black * rename to proprietary model and fix unittests * address comments * fix: docstyle and flake8 * address more comments and fix doc * put back doc utils for future refactoring * add prop model title in doc * doc update --------- Co-authored-by: liujiaor <[email protected]>
1 parent d0ce1a1 commit d32067d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+1963
-490
lines changed

doc/doc_utils/jumpstart_doc_utils.py

+68-15
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,12 @@ class Frameworks(str, Enum):
7474

7575
JUMPSTART_REGION = "eu-west-2"
7676
SDK_MANIFEST_FILE = "models_manifest.json"
77+
PROPRIETARY_SDK_MANIFEST_FILE = "proprietary-sdk-manifest.json"
7778
JUMPSTART_BUCKET_BASE_URL = "https://jumpstart-cache-prod-{}.s3.{}.amazonaws.com".format(
7879
JUMPSTART_REGION, JUMPSTART_REGION
7980
)
81+
PROPRIETARY_DOC_BUCKET = "https://jumpstart-cache-prod-us-west-2.s3.us-west-2.amazonaws.com"
82+
8083
TASK_MAP = {
8184
Tasks.IC: ProblemTypes.IMAGE_CLASSIFICATION,
8285
Tasks.IC_EMBEDDING: ProblemTypes.IMAGE_EMBEDDING,
@@ -152,18 +155,26 @@ class Frameworks(str, Enum):
152155
}
153156

154157

155-
def get_jumpstart_sdk_manifest():
156-
url = "{}/{}".format(JUMPSTART_BUCKET_BASE_URL, SDK_MANIFEST_FILE)
158+
def get_public_s3_json_object(url, timeout=60.0):
    """Download a publicly readable JSON document and return it parsed.

    Args:
        url (str): Full URL of the JSON object to fetch.
        timeout (float): Seconds to wait on the connection before giving up.
            Without a bound, a stalled connection would hang the docs build
            indefinitely.

    Returns:
        The decoded JSON payload (a ``dict`` or ``list`` depending on the document).

    Raises:
        urllib.error.URLError: If the object cannot be retrieved (including timeouts).
        json.JSONDecodeError: If the payload is not valid JSON.
    """
    with request.urlopen(url, timeout=timeout) as response:
        payload = response.read().decode("utf-8")
    return json.loads(payload)
160162

161163

162-
def get_jumpstart_sdk_spec(key):
163-
url = "{}/{}".format(JUMPSTART_BUCKET_BASE_URL, key)
164-
with request.urlopen(url) as f:
165-
model_spec = f.read().decode("utf-8")
166-
return json.loads(model_spec)
164+
def get_jumpstart_sdk_manifest():
    """Fetch and parse the open-weights JumpStart SDK manifest."""
    manifest_url = "{}/{}".format(JUMPSTART_BUCKET_BASE_URL, SDK_MANIFEST_FILE)
    return get_public_s3_json_object(manifest_url)
166+
167+
168+
def get_proprietary_sdk_manifest():
    """Fetch and parse the proprietary-model JumpStart SDK manifest."""
    manifest_url = "{}/{}".format(PROPRIETARY_DOC_BUCKET, PROPRIETARY_SDK_MANIFEST_FILE)
    return get_public_s3_json_object(manifest_url)
170+
171+
172+
def get_jumpstart_sdk_spec(s3_key: str):
    """Fetch and parse the open-weights model spec stored under ``s3_key``."""
    spec_url = "{}/{}".format(JUMPSTART_BUCKET_BASE_URL, s3_key)
    return get_public_s3_json_object(spec_url)
174+
175+
176+
def get_proprietary_sdk_spec(s3_key: str):
    """Fetch and parse the proprietary model spec stored under ``s3_key``."""
    spec_url = "{}/{}".format(PROPRIETARY_DOC_BUCKET, s3_key)
    return get_public_s3_json_object(spec_url)
167178

168179

169180
def get_model_task(id):
@@ -196,6 +207,45 @@ def get_model_source(url):
196207
return "Source"
197208

198209

210+
def _proprietary_version_sort_key(version):
    """Return a sort key that orders dotted version strings numerically.

    Plain string comparison ranks "9.0.0" above "10.0.0". Numeric components
    are compared as integers instead; non-numeric components fall back to
    string comparison and sort after numeric ones, so mixed keys stay comparable.
    """
    key = []
    for part in str(version).split("."):
        if part.isdecimal():
            key.append((0, int(part), ""))
        else:
            key.append((1, 0, part))
    return tuple(key)


def create_proprietary_model_table(sdk_manifest=None, spec_fetcher=None):
    """Build the RST ``list-table`` lines describing proprietary JumpStart models.

    Only the highest version of each model ID is listed. When two manifest
    entries carry an equal version, the one appearing first is kept.

    Args:
        sdk_manifest (list): Manifest entries, each a dict with ``model_id``,
            ``version``, ``min_version`` and ``spec_key``. Defaults to the
            manifest downloaded by ``get_proprietary_sdk_manifest()``.
        spec_fetcher (callable): Maps a ``spec_key`` to the model spec dict.
            Defaults to ``get_proprietary_sdk_spec``.

    Returns:
        list: Lines of RST (each newline-terminated) ready for ``writelines``.
    """
    if sdk_manifest is None:
        sdk_manifest = get_proprietary_sdk_manifest()
    if spec_fetcher is None:
        spec_fetcher = get_proprietary_sdk_spec

    # Directive options and rows must be indented under ".. list-table::";
    # the cells of a row ("- ...") must align with the first cell after "* ".
    proprietary_content_intro = [
        "\n",
        ".. list-table:: Available Proprietary Models\n",
        "   :widths: 50 20 20 20 20\n",
        "   :header-rows: 1\n",
        "   :class: datatable\n",
        "\n",
        "   * - Model ID\n",
        "     - Fine Tunable?\n",
        "     - Supported Version\n",
        "     - Min SDK Version\n",
        "     - Source\n",
    ]

    # Keep one manifest entry per model ID: the one with the highest version.
    top_versions = {}
    for model in sdk_manifest:
        current = top_versions.get(model["model_id"])
        if current is None or _proprietary_version_sort_key(
            current["version"]
        ) < _proprietary_version_sort_key(model["version"]):
            top_versions[model["model_id"]] = model

    proprietary_content_entries = []
    for model in top_versions.values():
        model_spec = spec_fetcher(model["spec_key"])
        proprietary_content_entries.append("   * - {}\n".format(model_spec["model_id"]))
        proprietary_content_entries.append("     - {}\n".format(False))  # TODO: support training
        proprietary_content_entries.append("     - {}\n".format(model["version"]))
        proprietary_content_entries.append("     - {}\n".format(model["min_version"]))
        proprietary_content_entries.append(
            "     - `{} <{}>`__ |external-link|\n".format("Source", model_spec.get("url"))
        )
    return proprietary_content_intro + proprietary_content_entries + ["\n"]
247+
248+
199249
def create_jumpstart_model_table():
200250
sdk_manifest = get_jumpstart_sdk_manifest()
201251
sdk_manifest_top_versions_for_models = {}
@@ -249,19 +299,19 @@ def create_jumpstart_model_table():
249299
file_content_intro.append(" - Source\n")
250300

251301
dynamic_table_files = []
252-
file_content_entries = []
302+
open_weight_content_entries = []
253303

254304
for model in sdk_manifest_top_versions_for_models.values():
255305
model_spec = get_jumpstart_sdk_spec(model["spec_key"])
256306
model_task = get_model_task(model_spec["model_id"])
257307
string_model_task = get_string_model_task(model_spec["model_id"])
258308
model_source = get_model_source(model_spec["url"])
259-
file_content_entries.append(" * - {}\n".format(model_spec["model_id"]))
260-
file_content_entries.append(" - {}\n".format(model_spec["training_supported"]))
261-
file_content_entries.append(" - {}\n".format(model["version"]))
262-
file_content_entries.append(" - {}\n".format(model["min_version"]))
263-
file_content_entries.append(" - {}\n".format(model_task))
264-
file_content_entries.append(
309+
open_weight_content_entries.append(" * - {}\n".format(model_spec["model_id"]))
310+
open_weight_content_entries.append(" - {}\n".format(model_spec["training_supported"]))
311+
open_weight_content_entries.append(" - {}\n".format(model["version"]))
312+
open_weight_content_entries.append(" - {}\n".format(model["min_version"]))
313+
open_weight_content_entries.append(" - {}\n".format(model_task))
314+
open_weight_content_entries.append(
265315
" - `{} <{}>`__ |external-link|\n".format(model_source, model_spec["url"])
266316
)
267317

@@ -299,7 +349,10 @@ def create_jumpstart_model_table():
299349
f.writelines(file_content_single_entry)
300350
f.close()
301351

352+
proprietary_content_entries = create_proprietary_model_table()
353+
302354
f = open("doc_utils/pretrainedmodels.rst", "a")
303355
f.writelines(file_content_intro)
304-
f.writelines(file_content_entries)
356+
f.writelines(open_weight_content_entries)
357+
f.writelines(proprietary_content_entries)
305358
f.close()

src/sagemaker/accept_types.py

+3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
1818
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
19+
from sagemaker.jumpstart.enums import JumpStartModelType
1920
from sagemaker.session import Session
2021

2122

@@ -75,6 +76,7 @@ def retrieve_default(
7576
tolerate_vulnerable_model: bool = False,
7677
tolerate_deprecated_model: bool = False,
7778
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
79+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
7880
) -> str:
7981
"""Retrieves the default accept type for the model matching the given arguments.
8082
@@ -114,4 +116,5 @@ def retrieve_default(
114116
tolerate_vulnerable_model,
115117
tolerate_deprecated_model,
116118
sagemaker_session=sagemaker_session,
119+
model_type=model_type,
117120
)

src/sagemaker/base_predictor.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@
5858
from sagemaker.model_monitor.model_monitoring import DEFAULT_REPOSITORY_NAME
5959

6060
from sagemaker.lineage.context import EndpointContext
61-
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
61+
from sagemaker.compute_resource_requirements.resource_requirements import (
62+
ResourceRequirements,
63+
)
6264

6365
LOGGER = logging.getLogger("sagemaker")
6466

src/sagemaker/content_types.py

+3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
1818
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
19+
from sagemaker.jumpstart.enums import JumpStartModelType
1920
from sagemaker.session import Session
2021

2122

@@ -75,6 +76,7 @@ def retrieve_default(
7576
tolerate_vulnerable_model: bool = False,
7677
tolerate_deprecated_model: bool = False,
7778
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
79+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
7880
) -> str:
7981
"""Retrieves the default content type for the model matching the given arguments.
8082
@@ -114,6 +116,7 @@ def retrieve_default(
114116
tolerate_vulnerable_model,
115117
tolerate_deprecated_model,
116118
sagemaker_session=sagemaker_session,
119+
model_type=model_type,
117120
)
118121

119122

src/sagemaker/deserializers.py

+3
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
3737
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
38+
from sagemaker.jumpstart.enums import JumpStartModelType
3839
from sagemaker.session import Session
3940

4041

@@ -95,6 +96,7 @@ def retrieve_default(
9596
tolerate_vulnerable_model: bool = False,
9697
tolerate_deprecated_model: bool = False,
9798
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
99+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
98100
) -> BaseDeserializer:
99101
"""Retrieves the default deserializer for the model matching the given arguments.
100102
@@ -135,4 +137,5 @@ def retrieve_default(
135137
tolerate_vulnerable_model,
136138
tolerate_deprecated_model,
137139
sagemaker_session=sagemaker_session,
140+
model_type=model_type,
138141
)

src/sagemaker/instance_types.py

+3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from sagemaker.jumpstart import utils as jumpstart_utils
2121
from sagemaker.jumpstart import artifacts
2222
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
23+
from sagemaker.jumpstart.enums import JumpStartModelType
2324
from sagemaker.session import Session
2425

2526
logger = logging.getLogger(__name__)
@@ -34,6 +35,7 @@ def retrieve_default(
3435
tolerate_deprecated_model: bool = False,
3536
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3637
training_instance_type: Optional[str] = None,
38+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
3739
) -> str:
3840
"""Retrieves the default instance type for the model matching the given arguments.
3941
@@ -85,6 +87,7 @@ def retrieve_default(
8587
tolerate_deprecated_model,
8688
sagemaker_session=sagemaker_session,
8789
training_instance_type=training_instance_type,
90+
model_type=model_type,
8891
)
8992

9093

src/sagemaker/jumpstart/accessors.py

+22-7
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from sagemaker.deprecations import deprecated
2020
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
21+
from sagemaker.jumpstart.enums import JumpStartModelType
2122
from sagemaker.jumpstart import cache
2223
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
2324

@@ -197,7 +198,9 @@ def _set_cache_and_region(region: str, cache_kwargs: dict) -> None:
197198

198199
@staticmethod
199200
def _get_manifest(
200-
region: str = JUMPSTART_DEFAULT_REGION_NAME, s3_client: Optional[boto3.client] = None
201+
region: str = JUMPSTART_DEFAULT_REGION_NAME,
202+
s3_client: Optional[boto3.client] = None,
203+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
201204
) -> List[JumpStartModelHeader]:
202205
"""Return entire JumpStart models manifest.
203206
@@ -215,13 +218,19 @@ def _get_manifest(
215218
additional_kwargs.update({"s3_client": s3_client})
216219

217220
cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
218-
{**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs}, region
221+
{**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs},
222+
region,
219223
)
220224
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
221-
return JumpStartModelsAccessor._cache.get_manifest() # type: ignore
225+
return JumpStartModelsAccessor._cache.get_manifest(model_type) # type: ignore
222226

223227
@staticmethod
224-
def get_model_header(region: str, model_id: str, version: str) -> JumpStartModelHeader:
228+
def get_model_header(
229+
region: str,
230+
model_id: str,
231+
version: str,
232+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
233+
) -> JumpStartModelHeader:
225234
"""Returns model header from JumpStart models cache.
226235
227236
Args:
@@ -234,12 +243,18 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel
234243
)
235244
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
236245
return JumpStartModelsAccessor._cache.get_header( # type: ignore
237-
model_id=model_id, semantic_version_str=version
246+
model_id=model_id,
247+
semantic_version_str=version,
248+
model_type=model_type,
238249
)
239250

240251
@staticmethod
241252
def get_model_specs(
242-
region: str, model_id: str, version: str, s3_client: Optional[boto3.client] = None
253+
region: str,
254+
model_id: str,
255+
version: str,
256+
s3_client: Optional[boto3.client] = None,
257+
model_type=JumpStartModelType.OPEN_WEIGHTS,
243258
) -> JumpStartModelSpecs:
244259
"""Returns model specs from JumpStart models cache.
245260
@@ -260,7 +275,7 @@ def get_model_specs(
260275
)
261276
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
262277
return JumpStartModelsAccessor._cache.get_specs( # type: ignore
263-
model_id=model_id, semantic_version_str=version
278+
model_id=model_id, version_str=version, model_type=model_type
264279
)
265280

266281
@staticmethod

src/sagemaker/jumpstart/artifacts/instance_types.py

+3
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
)
2323
from sagemaker.jumpstart.enums import (
2424
JumpStartScriptScope,
25+
JumpStartModelType,
2526
)
2627
from sagemaker.jumpstart.utils import (
2728
verify_model_region_and_return_specs,
@@ -38,6 +39,7 @@ def _retrieve_default_instance_type(
3839
tolerate_deprecated_model: bool = False,
3940
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
4041
training_instance_type: Optional[str] = None,
42+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
4143
) -> str:
4244
"""Retrieves the default instance type for the model.
4345
@@ -84,6 +86,7 @@ def _retrieve_default_instance_type(
8486
region=region,
8587
tolerate_vulnerable_model=tolerate_vulnerable_model,
8688
tolerate_deprecated_model=tolerate_deprecated_model,
89+
model_type=model_type,
8790
sagemaker_session=sagemaker_session,
8891
)
8992

src/sagemaker/jumpstart/artifacts/kwargs.py

+5
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
)
2323
from sagemaker.jumpstart.enums import (
2424
JumpStartScriptScope,
25+
JumpStartModelType,
2526
)
2627
from sagemaker.jumpstart.utils import (
2728
verify_model_region_and_return_specs,
@@ -35,6 +36,7 @@ def _retrieve_model_init_kwargs(
3536
tolerate_vulnerable_model: bool = False,
3637
tolerate_deprecated_model: bool = False,
3738
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
39+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
3840
) -> dict:
3941
"""Retrieves kwargs for `Model`.
4042
@@ -71,6 +73,7 @@ def _retrieve_model_init_kwargs(
7173
tolerate_vulnerable_model=tolerate_vulnerable_model,
7274
tolerate_deprecated_model=tolerate_deprecated_model,
7375
sagemaker_session=sagemaker_session,
76+
model_type=model_type,
7477
)
7578

7679
kwargs = deepcopy(model_specs.model_kwargs)
@@ -89,6 +92,7 @@ def _retrieve_model_deploy_kwargs(
8992
tolerate_vulnerable_model: bool = False,
9093
tolerate_deprecated_model: bool = False,
9194
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
95+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
9296
) -> dict:
9397
"""Retrieves kwargs for `Model.deploy`.
9498
@@ -128,6 +132,7 @@ def _retrieve_model_deploy_kwargs(
128132
tolerate_vulnerable_model=tolerate_vulnerable_model,
129133
tolerate_deprecated_model=tolerate_deprecated_model,
130134
sagemaker_session=sagemaker_session,
135+
model_type=model_type,
131136
)
132137

133138
if volume_size_supported(instance_type) and model_specs.inference_volume_size is not None:

src/sagemaker/jumpstart/artifacts/model_packages.py

+3
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
)
2323
from sagemaker.jumpstart.enums import (
2424
JumpStartScriptScope,
25+
JumpStartModelType,
2526
)
2627
from sagemaker.session import Session
2728

@@ -35,6 +36,7 @@ def _retrieve_model_package_arn(
3536
tolerate_vulnerable_model: bool = False,
3637
tolerate_deprecated_model: bool = False,
3738
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
39+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
3840
) -> Optional[str]:
3941
"""Retrieves associated model package arn for the model.
4042
@@ -74,6 +76,7 @@ def _retrieve_model_package_arn(
7476
tolerate_vulnerable_model=tolerate_vulnerable_model,
7577
tolerate_deprecated_model=tolerate_deprecated_model,
7678
sagemaker_session=sagemaker_session,
79+
model_type=model_type,
7780
)
7881

7982
if scope == JumpStartScriptScope.INFERENCE:

0 commit comments

Comments
 (0)