Skip to content

Commit c1a18bb

Browse files
committed
add more tests
1 parent 50d1925 commit c1a18bb

File tree

2 files changed

+55
-0
lines changed

2 files changed

+55
-0
lines changed

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6350,6 +6350,12 @@
63506350
"py_version": "py3",
63516351
},
63526352
"training_artifact_key": "pytorch-training/train-pytorch-eqa-bert-base-cased.tar.gz",
6353+
"predictor_specs": {
6354+
"supported_content_types": ["application/x-image"],
6355+
"supported_accept_types": ["application/json;verbose", "application/json"],
6356+
"default_content_type": "application/x-image",
6357+
"default_accept_type": "application/json",
6358+
},
63536359
"inference_environment_variables": [
63546360
{
63556361
"name": "SAGEMAKER_PROGRAM",

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
2626
JUMPSTART_DEFAULT_REGION_NAME,
2727
)
28+
from sagemaker.jumpstart.constants import (
29+
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
30+
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
31+
)
2832
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag, JumpStartModelType
2933

3034
from sagemaker.jumpstart.model import JumpStartModel
@@ -41,6 +45,7 @@
4145
overwrite_dictionary,
4246
get_special_model_spec_for_inference_component_based_endpoint,
4347
get_prototype_manifest,
48+
get_prototype_model_spec,
4449
)
4550
import boto3
4651

@@ -1365,6 +1370,50 @@ def test_jumpstart_model_session(
13651370
assert len(s3_clients) == 1
13661371
assert list(s3_clients)[0] == session.s3_client
13671372

1373+
@mock.patch.dict(
1374+
"sagemaker.jumpstart.cache.os.environ",
1375+
{
1376+
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/manifest/root",
1377+
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/specs/root",
1378+
},
1379+
)
1380+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
1381+
@mock.patch("sagemaker.jumpstart.factory.model.Session")
1382+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
1383+
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
1384+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
1385+
def test_model_local_mode(
1386+
self,
1387+
mock_model_deploy: mock.Mock,
1388+
mock_get_model_specs: mock.Mock,
1389+
mock_session: mock.Mock,
1390+
mock_get_manifest: mock.Mock,
1391+
):
1392+
mock_get_model_specs.side_effect = get_prototype_model_spec
1393+
mock_get_manifest.side_effect = (
1394+
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
1395+
)
1396+
mock_model_deploy.return_value = default_predictor
1397+
1398+
model_id, _ = "pytorch-eqa-bert-base-cased", "*"
1399+
1400+
mock_session.return_value = sagemaker_session
1401+
1402+
model = JumpStartModel(model_id=model_id, instance_type="ml.p2.xlarge")
1403+
1404+
model.deploy()
1405+
1406+
mock_model_deploy.assert_called_once_with(
1407+
initial_instance_count=1,
1408+
instance_type="ml.p2.xlarge",
1409+
tags=[
1410+
{"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"},
1411+
{"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
1412+
],
1413+
wait=True,
1414+
endpoint_logging=False,
1415+
)
1416+
13681417

13691418
def test_jumpstart_model_requires_model_id():
13701419
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)