Skip to content

fix: typo in jumpstart manifest and refine tests #4558

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class JumpStartS3FileType(str, Enum):

OPEN_WEIGHT_MANIFEST = "manifest"
OPEN_WEIGHT_SPECS = "specs"
PROPRIETARY_MANIFEST = "proptietary_manifest"
PROPRIETARY_MANIFEST = "proprietary_manifest"
PROPRIETARY_SPECS = "proprietary_specs"


Expand Down
23 changes: 8 additions & 15 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,20 +765,6 @@ def validate_model_id_and_get_type(
ValueError: If the script is not supported by JumpStart.
"""

def _get_model_type(
model_id: str,
open_weights_model_ids: Set[str],
proprietary_model_ids: Set[str],
script: enums.JumpStartScriptScope,
) -> Optional[enums.JumpStartModelType]:
if model_id in open_weights_model_ids:
return enums.JumpStartModelType.OPEN_WEIGHTS
if model_id in proprietary_model_ids:
if script == enums.JumpStartScriptScope.INFERENCE:
return enums.JumpStartModelType.PROPRIETARY
raise ValueError(f"Unsupported script for Marketplace models: {script}")
return None

if model_id in {None, ""}:
return None
if not isinstance(model_id, str):
Expand All @@ -792,12 +778,19 @@ def _get_model_type(
)
open_weight_model_id_set = {model.model_id for model in models_manifest_list}

if model_id in open_weight_model_id_set:
return enums.JumpStartModelType.OPEN_WEIGHTS

proprietary_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(
region=region, s3_client=s3_client, model_type=enums.JumpStartModelType.PROPRIETARY
)

proprietary_model_id_set = {model.model_id for model in proprietary_manifest_list}
return _get_model_type(model_id, open_weight_model_id_set, proprietary_model_id_set, script)
if model_id in proprietary_model_id_set:
if script == enums.JumpStartScriptScope.INFERENCE:
return enums.JumpStartModelType.PROPRIETARY
raise ValueError(f"Unsupported script for Proprietary models: {script}")
return None


def get_jumpstart_model_id_version_from_resource_arn(
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6350,6 +6350,12 @@
"py_version": "py3",
},
"training_artifact_key": "pytorch-training/train-pytorch-eqa-bert-base-cased.tar.gz",
"predictor_specs": {
"supported_content_types": ["application/x-image"],
"supported_accept_types": ["application/json;verbose", "application/json"],
"default_content_type": "application/x-image",
"default_accept_type": "application/json",
},
"inference_environment_variables": [
{
"name": "SAGEMAKER_PROGRAM",
Expand Down
49 changes: 49 additions & 0 deletions tests/unit/sagemaker/jumpstart/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
JUMPSTART_DEFAULT_REGION_NAME,
)
from sagemaker.jumpstart.constants import (
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
)
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag, JumpStartModelType

from sagemaker.jumpstart.model import JumpStartModel
Expand All @@ -41,6 +45,7 @@
overwrite_dictionary,
get_special_model_spec_for_inference_component_based_endpoint,
get_prototype_manifest,
get_prototype_model_spec,
)
import boto3

Expand Down Expand Up @@ -1365,6 +1370,50 @@ def test_jumpstart_model_session(
assert len(s3_clients) == 1
assert list(s3_clients)[0] == session.s3_client

@mock.patch.dict(
    "sagemaker.jumpstart.cache.os.environ",
    {
        ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/manifest/root",
        ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/specs/root",
    },
)
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
@mock.patch("sagemaker.jumpstart.factory.model.Session")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
def test_model_local_mode(
    self,
    mock_model_deploy: mock.Mock,
    mock_get_model_specs: mock.Mock,
    mock_session: mock.Mock,
    mock_get_manifest: mock.Mock,
):
    """Deploy a JumpStartModel while the local manifest/specs override env vars are set.

    The two ``*_LOCAL_ROOT_DIR_OVERRIDE`` variables are patched into
    ``sagemaker.jumpstart.cache.os.environ`` for the duration of the test.
    Manifest and spec lookups are served by the prototype helpers, and the
    test asserts that ``Model.deploy`` is invoked exactly once with the
    expected instance settings and JumpStart tags.

    Mock arguments are injected bottom-up from the ``mock.patch`` decorators;
    the ``JUMPSTART_DEFAULT_REGION_NAME`` patch supplies an explicit
    replacement value and therefore injects no argument.
    """

    def _prototype_manifest(region, model_type, *args, **kwargs):
        # Only region and model_type are forwarded; any extra arguments
        # (e.g. s3_client) are accepted and ignored.
        return get_prototype_manifest(region, model_type)

    # Wire up the patched collaborators.
    mock_session.return_value = sagemaker_session
    mock_model_deploy.return_value = default_predictor
    mock_get_manifest.side_effect = _prototype_manifest
    mock_get_model_specs.side_effect = get_prototype_model_spec

    model_id = "pytorch-eqa-bert-base-cased"

    jumpstart_model = JumpStartModel(model_id=model_id, instance_type="ml.p2.xlarge")
    jumpstart_model.deploy()

    expected_tags = [
        {"Key": JumpStartTag.MODEL_ID, "Value": model_id},
        {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
    ]
    mock_model_deploy.assert_called_once_with(
        initial_instance_count=1,
        instance_type="ml.p2.xlarge",
        tags=expected_tags,
        wait=True,
        endpoint_logging=False,
    )


def test_jumpstart_model_requires_model_id():
with pytest.raises(ValueError):
Expand Down
80 changes: 38 additions & 42 deletions tests/unit/sagemaker/jumpstart/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
VulnerableJumpStartModelError,
)
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_prototype_manifest
from mock import MagicMock


Expand Down Expand Up @@ -1178,7 +1178,7 @@ def test_mime_type_enum_from_str():
class TestIsValidModelId(TestCase):
@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest")
@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs")
def test_validate_model_id_and_get_type_true(
def test_validate_model_id_and_get_type_open_weights(
self,
mock_get_model_specs: Mock,
mock_get_manifest: Mock,
Expand All @@ -1197,11 +1197,11 @@ def test_validate_model_id_and_get_type_true(
)

with patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type", patched):
self.assertTrue(utils.validate_model_id_and_get_type("bee"))
assert utils.validate_model_id_and_get_type("bee") == JumpStartModelType.OPEN_WEIGHTS
mock_get_manifest.assert_called_with(
region=JUMPSTART_DEFAULT_REGION_NAME,
s3_client=mock_s3_client_value,
model_type=JumpStartModelType.PROPRIETARY,
model_type=JumpStartModelType.OPEN_WEIGHTS,
)
mock_get_model_specs.assert_not_called()

Expand All @@ -1215,25 +1215,30 @@ def test_validate_model_id_and_get_type_true(
]

mock_get_model_specs.return_value = Mock(training_supported=True)
self.assertTrue(
self.assertIsNone(
utils.validate_model_id_and_get_type(
"invalid", script=JumpStartScriptScope.TRAINING
)
)
assert (
utils.validate_model_id_and_get_type("bee", script=JumpStartScriptScope.TRAINING)
== JumpStartModelType.OPEN_WEIGHTS
)

mock_get_manifest.assert_called_with(
region=JUMPSTART_DEFAULT_REGION_NAME,
s3_client=mock_s3_client_value,
model_type=JumpStartModelType.PROPRIETARY,
model_type=JumpStartModelType.OPEN_WEIGHTS,
)

@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest")
@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs")
def test_validate_model_id_and_get_type_false(
def test_validate_model_id_and_get_type_invalid(
self, mock_get_model_specs: Mock, mock_get_manifest: Mock
):
mock_get_manifest.return_value = [
Mock(model_id="ay"),
Mock(model_id="bee"),
Mock(model_id="see"),
]
mock_get_manifest.side_effect = (
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
)

mock_session_value = DEFAULT_JUMPSTART_SAGEMAKER_SESSION
mock_s3_client_value = mock_session_value.s3_client
Expand All @@ -1244,10 +1249,10 @@ def test_validate_model_id_and_get_type_false(

with patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type", patched):

self.assertFalse(utils.validate_model_id_and_get_type("dee"))
self.assertFalse(utils.validate_model_id_and_get_type(""))
self.assertFalse(utils.validate_model_id_and_get_type(None))
self.assertFalse(utils.validate_model_id_and_get_type(set()))
self.assertIsNone(utils.validate_model_id_and_get_type("dee"))
self.assertIsNone(utils.validate_model_id_and_get_type(""))
self.assertIsNone(utils.validate_model_id_and_get_type(None))
self.assertIsNone(utils.validate_model_id_and_get_type(set()))

mock_get_manifest.assert_called()

Expand All @@ -1256,53 +1261,44 @@ def test_validate_model_id_and_get_type_false(
mock_get_manifest.reset_mock()
mock_get_model_specs.reset_mock()

mock_get_manifest.return_value = [
Mock(model_id="ay"),
Mock(model_id="bee"),
Mock(model_id="see"),
]
self.assertFalse(
utils.validate_model_id_and_get_type("dee", script=JumpStartScriptScope.TRAINING)
assert (
utils.validate_model_id_and_get_type("ai21-summarization")
== JumpStartModelType.PROPRIETARY
)
self.assertIsNone(utils.validate_model_id_and_get_type("ai21-summarization-2"))

mock_get_manifest.assert_called_with(
region=JUMPSTART_DEFAULT_REGION_NAME,
s3_client=mock_s3_client_value,
model_type=JumpStartModelType.PROPRIETARY,
)

mock_get_manifest.reset_mock()

self.assertFalse(
self.assertIsNone(
utils.validate_model_id_and_get_type("dee", script=JumpStartScriptScope.TRAINING)
)
self.assertFalse(
self.assertIsNone(
utils.validate_model_id_and_get_type("", script=JumpStartScriptScope.TRAINING)
)
self.assertFalse(
self.assertIsNone(
utils.validate_model_id_and_get_type(None, script=JumpStartScriptScope.TRAINING)
)
self.assertFalse(
self.assertIsNone(
utils.validate_model_id_and_get_type(set(), script=JumpStartScriptScope.TRAINING)
)

mock_get_model_specs.assert_not_called()
assert (
utils.validate_model_id_and_get_type("pytorch-eqa-bert-base-cased")
== JumpStartModelType.OPEN_WEIGHTS
)
mock_get_manifest.assert_called_with(
region=JUMPSTART_DEFAULT_REGION_NAME,
s3_client=mock_s3_client_value,
model_type=JumpStartModelType.PROPRIETARY,
model_type=JumpStartModelType.OPEN_WEIGHTS,
)

mock_get_manifest.reset_mock()
mock_get_model_specs.reset_mock()

mock_get_model_specs.return_value = Mock(training_supported=False)
self.assertTrue(
utils.validate_model_id_and_get_type("ay", script=JumpStartScriptScope.TRAINING)
)
mock_get_manifest.assert_called_with(
region=JUMPSTART_DEFAULT_REGION_NAME,
s3_client=mock_s3_client_value,
model_type=JumpStartModelType.PROPRIETARY,
with pytest.raises(ValueError):
utils.validate_model_id_and_get_type(
"ai21-summarization", script=JumpStartScriptScope.TRAINING
)


Expand Down