Skip to content

Commit 347b599

Browse files
Captainialiujiaorr
authored andcommitted
feat: support JumpStart proprietary models (aws#4467)
* feat: add proprietary manifest/specs parsing add unittests for test_cache small refactoring address comments and more unittests fix linting and fix more tests fix: pylint feat: JumpStartModel class for prop models * remove unused imports and fix docstyle * fix: remove unused args * fix: remove unused args * fix: more unused vars * fix: slow tests * fix: unittests * added more tests to cover some lines * remove estimator warn check * chore: address comments re performance * fix: address comments * complete list experience and other fixes * fix: pylint * add doc utils and fix pylint * fix: docstyle * fix: doc * fix: default payloads * fix: doc and tags and enums * fix: jumpstart doc * rename to open_weights and fix filtering * update filter name * doc update * fix: black * rename to proprietary model and fix unittests * address comments * fix: docstyle and flake8 * address more comments and fix doc * put back doc utils for future refactoring * add prop model title in doc * doc update --------- Co-authored-by: liujiaor <[email protected]>
1 parent f7677d8 commit 347b599

File tree

5 files changed

+3
-2
lines changed

5 files changed

+3
-2
lines changed

src/sagemaker/jumpstart/cache.py

-1
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,6 @@ def _retrieval_function(
429429
"""
430430

431431
data_type, id_info = key.data_type, key.id_info
432-
433432
if data_type in {
434433
JumpStartS3FileType.OPEN_WEIGHT_MANIFEST,
435434
JumpStartS3FileType.PROPRIETARY_MANIFEST,

src/sagemaker/jumpstart/types.py

+1
Original file line numberDiff line numberDiff line change
@@ -1434,6 +1434,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
14341434
"model_version",
14351435
"model_type",
14361436
"hub_arn",
1437+
"model_type",
14371438
"region",
14381439
"tolerate_deprecated_model",
14391440
"tolerate_vulnerable_model",

tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py

+1
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode
126126
model_type=JumpStartModelType.OPEN_WEIGHTS,
127127
hub_arn=None,
128128
s3_client=mock_client,
129+
model_type=JumpStartModelType.OPEN_WEIGHTS,
129130
)
130131

131132
patched_get_model_specs.reset_mock()

tests/unit/sagemaker/jumpstart/utils.py

-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
HubContentType,
3232
)
3333
from sagemaker.jumpstart.enums import JumpStartModelType
34-
3534
from sagemaker.jumpstart.utils import get_formatted_manifest
3635
from tests.unit.sagemaker.jumpstart.constants import (
3736
PROTOTYPICAL_MODEL_SPECS_DICT,

tests/unit/sagemaker/script_uris/jumpstart/test_common.py

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def test_jumpstart_common_script_uri(
5454
s3_client=mock_client,
5555
model_type=JumpStartModelType.OPEN_WEIGHTS,
5656
hub_arn=None,
57+
model_type=JumpStartModelType.OPEN_WEIGHTS,
5758
)
5859
patched_verify_model_region_and_return_specs.assert_called_once()
5960

0 commit comments

Comments
 (0)