Skip to content

Commit ac1b7dd

Browse files
Captainialiujiaorr
authored andcommitted
feat: support JumpStart proprietary models (aws#4467)
* feat: add proprietary manifest/specs parsing add unittests for test_cache small refactoring address comments and more unittests fix linting and fix more tests fix: pylint feat: JumpStartModel class for prop models * remove unused imports and fix docstyle * fix: remove unused args * fix: remove unused args * fix: more unused vars * fix: slow tests * fix: unittests * added more tests to cover some lines * remove estimator warn check * chore: address comments re performance * fix: address comments * complete list experience and other fixes * fix: pylint * add doc utils and fix pylint * fix: docstyle * fix: doc * fix: default payloads * fix: doc and tags and enums * fix: jumpstart doc * rename to open_weights and fix filtering * update filter name * doc update * fix: black * rename to proprietary model and fix unittests * address comments * fix: docstyle and flake8 * address more comments and fix doc * put back doc utils for future refactoring * add prop model title in doc * doc update --------- Co-authored-by: liujiaor <[email protected]>
1 parent 1c8ca5e commit ac1b7dd

File tree

6 files changed

+8
-4
lines changed

6 files changed

+8
-4
lines changed

src/sagemaker/jumpstart/cache.py

-2
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,6 @@ def _retrieval_function(
420420
"""
421421

422422
data_type, id_info = key.data_type, key.id_info
423-
424423
if data_type in {
425424
JumpStartS3FileType.OPEN_WEIGHT_MANIFEST,
426425
JumpStartS3FileType.PROPRIETARY_MANIFEST,
@@ -434,7 +433,6 @@ def _retrieval_function(
434433
formatted_content=utils.get_formatted_manifest(formatted_body),
435434
md5_hash=etag,
436435
)
437-
438436
if data_type in {
439437
JumpStartS3FileType.OPEN_WEIGHT_SPECS,
440438
JumpStartS3FileType.PROPRIETARY_SPECS,

src/sagemaker/jumpstart/types.py

+1
Original file line numberDiff line numberDiff line change
@@ -1399,6 +1399,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
13991399
"model_version",
14001400
"model_type",
14011401
"hub_arn",
1402+
"model_type",
14021403
"region",
14031404
"tolerate_deprecated_model",
14041405
"tolerate_vulnerable_model",

tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py

+1
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode
126126
model_type=JumpStartModelType.OPEN_WEIGHTS,
127127
hub_arn=None,
128128
s3_client=mock_client,
129+
model_type=JumpStartModelType.OPEN_WEIGHTS,
129130
)
130131

131132
patched_get_model_specs.reset_mock()

tests/unit/sagemaker/jumpstart/utils.py

-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
HubContentType,
3232
)
3333
from sagemaker.jumpstart.enums import JumpStartModelType
34-
3534
from sagemaker.jumpstart.utils import get_formatted_manifest
3635
from tests.unit.sagemaker.jumpstart.constants import (
3736
PROTOTYPICAL_MODEL_SPECS_DICT,

tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,12 @@ def test_jumpstart_resource_requirements_instance_type_variants(patched_get_mode
110110
}
111111

112112

113+
@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type")
113114
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
114-
def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs):
115+
def test_jumpstart_no_supported_resource_requirements(
116+
patched_get_model_specs, patched_validate_model_id_and_get_type
117+
):
118+
115119
patched_get_model_specs.side_effect = get_special_model_spec
116120
region = "us-west-2"
117121
mock_client = boto3.client("s3")

tests/unit/sagemaker/script_uris/jumpstart/test_common.py

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def test_jumpstart_common_script_uri(
5454
s3_client=mock_client,
5555
model_type=JumpStartModelType.OPEN_WEIGHTS,
5656
hub_arn=None,
57+
model_type=JumpStartModelType.OPEN_WEIGHTS,
5758
)
5859
patched_verify_model_region_and_return_specs.assert_called_once()
5960

0 commit comments

Comments
 (0)