Skip to content

Commit 208ef86

Browse files
authored
fix: typo in jumpstart manifest and refine tests (#4558)
* fix: type in manifest and refinement of tests * add more tests * add back a few tests * black * added one more test * add test
1 parent 3c04759 commit 208ef86

File tree

5 files changed

+102
-58
lines changed

5 files changed

+102
-58
lines changed

src/sagemaker/jumpstart/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class JumpStartS3FileType(str, Enum):
105105

106106
OPEN_WEIGHT_MANIFEST = "manifest"
107107
OPEN_WEIGHT_SPECS = "specs"
108-
PROPRIETARY_MANIFEST = "proptietary_manifest"
108+
PROPRIETARY_MANIFEST = "proprietary_manifest"
109109
PROPRIETARY_SPECS = "proprietary_specs"
110110

111111

src/sagemaker/jumpstart/utils.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -765,20 +765,6 @@ def validate_model_id_and_get_type(
765765
ValueError: If the script is not supported by JumpStart.
766766
"""
767767

768-
def _get_model_type(
769-
model_id: str,
770-
open_weights_model_ids: Set[str],
771-
proprietary_model_ids: Set[str],
772-
script: enums.JumpStartScriptScope,
773-
) -> Optional[enums.JumpStartModelType]:
774-
if model_id in open_weights_model_ids:
775-
return enums.JumpStartModelType.OPEN_WEIGHTS
776-
if model_id in proprietary_model_ids:
777-
if script == enums.JumpStartScriptScope.INFERENCE:
778-
return enums.JumpStartModelType.PROPRIETARY
779-
raise ValueError(f"Unsupported script for Marketplace models: {script}")
780-
return None
781-
782768
if model_id in {None, ""}:
783769
return None
784770
if not isinstance(model_id, str):
@@ -792,12 +778,19 @@ def _get_model_type(
792778
)
793779
open_weight_model_id_set = {model.model_id for model in models_manifest_list}
794780

781+
if model_id in open_weight_model_id_set:
782+
return enums.JumpStartModelType.OPEN_WEIGHTS
783+
795784
proprietary_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(
796785
region=region, s3_client=s3_client, model_type=enums.JumpStartModelType.PROPRIETARY
797786
)
798787

799788
proprietary_model_id_set = {model.model_id for model in proprietary_manifest_list}
800-
return _get_model_type(model_id, open_weight_model_id_set, proprietary_model_id_set, script)
789+
if model_id in proprietary_model_id_set:
790+
if script == enums.JumpStartScriptScope.INFERENCE:
791+
return enums.JumpStartModelType.PROPRIETARY
792+
raise ValueError(f"Unsupported script for Proprietary models: {script}")
793+
return None
801794

802795

803796
def get_jumpstart_model_id_version_from_resource_arn(

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6350,6 +6350,12 @@
63506350
"py_version": "py3",
63516351
},
63526352
"training_artifact_key": "pytorch-training/train-pytorch-eqa-bert-base-cased.tar.gz",
6353+
"predictor_specs": {
6354+
"supported_content_types": ["application/x-image"],
6355+
"supported_accept_types": ["application/json;verbose", "application/json"],
6356+
"default_content_type": "application/x-image",
6357+
"default_accept_type": "application/json",
6358+
},
63536359
"inference_environment_variables": [
63546360
{
63556361
"name": "SAGEMAKER_PROGRAM",

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
2626
JUMPSTART_DEFAULT_REGION_NAME,
2727
)
28+
from sagemaker.jumpstart.constants import (
29+
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
30+
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
31+
)
2832
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag, JumpStartModelType
2933

3034
from sagemaker.jumpstart.model import JumpStartModel
@@ -41,6 +45,7 @@
4145
overwrite_dictionary,
4246
get_special_model_spec_for_inference_component_based_endpoint,
4347
get_prototype_manifest,
48+
get_prototype_model_spec,
4449
)
4550
import boto3
4651

@@ -1365,6 +1370,50 @@ def test_jumpstart_model_session(
13651370
assert len(s3_clients) == 1
13661371
assert list(s3_clients)[0] == session.s3_client
13671372

1373+
@mock.patch.dict(
1374+
"sagemaker.jumpstart.cache.os.environ",
1375+
{
1376+
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/manifest/root",
1377+
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/specs/root",
1378+
},
1379+
)
1380+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
1381+
@mock.patch("sagemaker.jumpstart.factory.model.Session")
1382+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
1383+
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
1384+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
1385+
def test_model_local_mode(
1386+
self,
1387+
mock_model_deploy: mock.Mock,
1388+
mock_get_model_specs: mock.Mock,
1389+
mock_session: mock.Mock,
1390+
mock_get_manifest: mock.Mock,
1391+
):
1392+
mock_get_model_specs.side_effect = get_prototype_model_spec
1393+
mock_get_manifest.side_effect = (
1394+
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
1395+
)
1396+
mock_model_deploy.return_value = default_predictor
1397+
1398+
model_id, _ = "pytorch-eqa-bert-base-cased", "*"
1399+
1400+
mock_session.return_value = sagemaker_session
1401+
1402+
model = JumpStartModel(model_id=model_id, instance_type="ml.p2.xlarge")
1403+
1404+
model.deploy()
1405+
1406+
mock_model_deploy.assert_called_once_with(
1407+
initial_instance_count=1,
1408+
instance_type="ml.p2.xlarge",
1409+
tags=[
1410+
{"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"},
1411+
{"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
1412+
],
1413+
wait=True,
1414+
endpoint_logging=False,
1415+
)
1416+
13681417

13691418
def test_jumpstart_model_requires_model_id():
13701419
with pytest.raises(ValueError):

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 38 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
VulnerableJumpStartModelError,
4141
)
4242
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId
43-
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec
43+
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_prototype_manifest
4444
from mock import MagicMock
4545

4646

@@ -1178,7 +1178,7 @@ def test_mime_type_enum_from_str():
11781178
class TestIsValidModelId(TestCase):
11791179
@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest")
11801180
@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs")
1181-
def test_validate_model_id_and_get_type_true(
1181+
def test_validate_model_id_and_get_type_open_weights(
11821182
self,
11831183
mock_get_model_specs: Mock,
11841184
mock_get_manifest: Mock,
@@ -1197,11 +1197,11 @@ def test_validate_model_id_and_get_type_true(
11971197
)
11981198

11991199
with patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type", patched):
1200-
self.assertTrue(utils.validate_model_id_and_get_type("bee"))
1200+
assert utils.validate_model_id_and_get_type("bee") == JumpStartModelType.OPEN_WEIGHTS
12011201
mock_get_manifest.assert_called_with(
12021202
region=JUMPSTART_DEFAULT_REGION_NAME,
12031203
s3_client=mock_s3_client_value,
1204-
model_type=JumpStartModelType.PROPRIETARY,
1204+
model_type=JumpStartModelType.OPEN_WEIGHTS,
12051205
)
12061206
mock_get_model_specs.assert_not_called()
12071207

@@ -1215,25 +1215,30 @@ def test_validate_model_id_and_get_type_true(
12151215
]
12161216

12171217
mock_get_model_specs.return_value = Mock(training_supported=True)
1218-
self.assertTrue(
1218+
self.assertIsNone(
1219+
utils.validate_model_id_and_get_type(
1220+
"invalid", script=JumpStartScriptScope.TRAINING
1221+
)
1222+
)
1223+
assert (
12191224
utils.validate_model_id_and_get_type("bee", script=JumpStartScriptScope.TRAINING)
1225+
== JumpStartModelType.OPEN_WEIGHTS
12201226
)
1227+
12211228
mock_get_manifest.assert_called_with(
12221229
region=JUMPSTART_DEFAULT_REGION_NAME,
12231230
s3_client=mock_s3_client_value,
1224-
model_type=JumpStartModelType.PROPRIETARY,
1231+
model_type=JumpStartModelType.OPEN_WEIGHTS,
12251232
)
12261233

12271234
@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest")
12281235
@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs")
1229-
def test_validate_model_id_and_get_type_false(
1236+
def test_validate_model_id_and_get_type_invalid(
12301237
self, mock_get_model_specs: Mock, mock_get_manifest: Mock
12311238
):
1232-
mock_get_manifest.return_value = [
1233-
Mock(model_id="ay"),
1234-
Mock(model_id="bee"),
1235-
Mock(model_id="see"),
1236-
]
1239+
mock_get_manifest.side_effect = (
1240+
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
1241+
)
12371242

12381243
mock_session_value = DEFAULT_JUMPSTART_SAGEMAKER_SESSION
12391244
mock_s3_client_value = mock_session_value.s3_client
@@ -1244,10 +1249,10 @@ def test_validate_model_id_and_get_type_false(
12441249

12451250
with patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type", patched):
12461251

1247-
self.assertFalse(utils.validate_model_id_and_get_type("dee"))
1248-
self.assertFalse(utils.validate_model_id_and_get_type(""))
1249-
self.assertFalse(utils.validate_model_id_and_get_type(None))
1250-
self.assertFalse(utils.validate_model_id_and_get_type(set()))
1252+
self.assertIsNone(utils.validate_model_id_and_get_type("dee"))
1253+
self.assertIsNone(utils.validate_model_id_and_get_type(""))
1254+
self.assertIsNone(utils.validate_model_id_and_get_type(None))
1255+
self.assertIsNone(utils.validate_model_id_and_get_type(set()))
12511256

12521257
mock_get_manifest.assert_called()
12531258

@@ -1256,53 +1261,44 @@ def test_validate_model_id_and_get_type_false(
12561261
mock_get_manifest.reset_mock()
12571262
mock_get_model_specs.reset_mock()
12581263

1259-
mock_get_manifest.return_value = [
1260-
Mock(model_id="ay"),
1261-
Mock(model_id="bee"),
1262-
Mock(model_id="see"),
1263-
]
1264-
self.assertFalse(
1265-
utils.validate_model_id_and_get_type("dee", script=JumpStartScriptScope.TRAINING)
1264+
assert (
1265+
utils.validate_model_id_and_get_type("ai21-summarization")
1266+
== JumpStartModelType.PROPRIETARY
12661267
)
1268+
self.assertIsNone(utils.validate_model_id_and_get_type("ai21-summarization-2"))
1269+
12671270
mock_get_manifest.assert_called_with(
12681271
region=JUMPSTART_DEFAULT_REGION_NAME,
12691272
s3_client=mock_s3_client_value,
12701273
model_type=JumpStartModelType.PROPRIETARY,
12711274
)
12721275

1273-
mock_get_manifest.reset_mock()
1274-
1275-
self.assertFalse(
1276+
self.assertIsNone(
12761277
utils.validate_model_id_and_get_type("dee", script=JumpStartScriptScope.TRAINING)
12771278
)
1278-
self.assertFalse(
1279+
self.assertIsNone(
12791280
utils.validate_model_id_and_get_type("", script=JumpStartScriptScope.TRAINING)
12801281
)
1281-
self.assertFalse(
1282+
self.assertIsNone(
12821283
utils.validate_model_id_and_get_type(None, script=JumpStartScriptScope.TRAINING)
12831284
)
1284-
self.assertFalse(
1285+
self.assertIsNone(
12851286
utils.validate_model_id_and_get_type(set(), script=JumpStartScriptScope.TRAINING)
12861287
)
12871288

1288-
mock_get_model_specs.assert_not_called()
1289+
assert (
1290+
utils.validate_model_id_and_get_type("pytorch-eqa-bert-base-cased")
1291+
== JumpStartModelType.OPEN_WEIGHTS
1292+
)
12891293
mock_get_manifest.assert_called_with(
12901294
region=JUMPSTART_DEFAULT_REGION_NAME,
12911295
s3_client=mock_s3_client_value,
1292-
model_type=JumpStartModelType.PROPRIETARY,
1296+
model_type=JumpStartModelType.OPEN_WEIGHTS,
12931297
)
12941298

1295-
mock_get_manifest.reset_mock()
1296-
mock_get_model_specs.reset_mock()
1297-
1298-
mock_get_model_specs.return_value = Mock(training_supported=False)
1299-
self.assertTrue(
1300-
utils.validate_model_id_and_get_type("ay", script=JumpStartScriptScope.TRAINING)
1301-
)
1302-
mock_get_manifest.assert_called_with(
1303-
region=JUMPSTART_DEFAULT_REGION_NAME,
1304-
s3_client=mock_s3_client_value,
1305-
model_type=JumpStartModelType.PROPRIETARY,
1299+
with pytest.raises(ValueError):
1300+
utils.validate_model_id_and_get_type(
1301+
"ai21-summarization", script=JumpStartScriptScope.TRAINING
13061302
)
13071303

13081304

0 commit comments

Comments
 (0)