Skip to content

JumpStart Gated Model Support in ModelBuilder Local Modes #4567

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Apr 11, 2024
18 changes: 18 additions & 0 deletions src/sagemaker/serve/builder/jumpstart_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,11 @@ def _build_for_jumpstart(self):

logger.info("JumpStart ID %s is packaged with Image URI: %s", self.model, image_uri)

if self._is_gated_model(pysdk_model) and self.mode != Mode.SAGEMAKER_ENDPOINT:
raise ValueError(
"JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode."
)

if "djl-inference" in image_uri:
logger.info("Building for DJL JumpStart Model ID...")
self.model_server = ModelServer.DJL_SERVING
Expand All @@ -469,3 +474,16 @@ def _build_for_jumpstart(self):
)

return self.pysdk_model

def _is_gated_model(self, model) -> bool:
"""Determine if ``this`` Model is Gated

Args:
model (Model): Jumpstart Model
Returns:
bool: ``True`` if ``this`` Model is Gated
"""
s3_uri = model.model_data
if isinstance(s3_uri, dict):
s3_uri = s3_uri.get("S3DataSource").get("S3Uri")
return "private" in s3_uri
43 changes: 43 additions & 0 deletions tests/integ/sagemaker/serve/test_serve_js_happy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import pytest

from sagemaker.serve import Mode
from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from tests.integ.sagemaker.serve.constants import (
Expand All @@ -32,6 +33,7 @@
{"generated_text": "Hello, I'm a language model, and I'm here to help you with your English."}
]
JS_MODEL_ID = "huggingface-textgeneration1-gpt-neo-125m-fp16"
JS_GATED_MODEL_ID = "huggingface-llm-zephyr-7b-gemma"
ROLE_NAME = "SageMakerRole"


Expand All @@ -46,6 +48,17 @@ def happy_model_builder(sagemaker_session):
)


@pytest.fixture
def happy_model_builder_gated_model(sagemaker_session):
iam_client = sagemaker_session.boto_session.client("iam")
return ModelBuilder(
model=JS_GATED_MODEL_ID,
schema_builder=SchemaBuilder(SAMPLE_PROMPT, SAMPLE_RESPONSE),
role_arn=iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"],
sagemaker_session=sagemaker_session,
)


@pytest.mark.skipif(
PYTHON_VERSION_IS_NOT_310,
reason="The goal of these test are to test the serving components of our feature",
Expand Down Expand Up @@ -75,3 +88,33 @@ def test_happy_tgi_sagemaker_endpoint(happy_model_builder, gpu_instance_type):
)
if caught_ex:
raise caught_ex


@pytest.mark.skipif(
PYTHON_VERSION_IS_NOT_310,
reason="The goal of these test are to test the serving components of our feature",
)
@pytest.mark.slow_test
def test_happy_js_gated_model(happy_model_builder_gated_model, gpu_instance_type):
logger.info("Running in SAGEMAKER_ENDPOINT mode...")
happy_model_builder_gated_model.build()


@pytest.mark.skipif(
PYTHON_VERSION_IS_NOT_310,
reason="The goal of these test are to test the serving components of our feature",
)
@pytest.mark.slow_test
def test_js_gated_model_throws(happy_model_builder_gated_model, gpu_instance_type):
logger.info("Running in Local mode...")
model_builder = ModelBuilder(
model=JS_GATED_MODEL_ID,
schema_builder=SchemaBuilder(SAMPLE_PROMPT, SAMPLE_RESPONSE),
mode=Mode.LOCAL_CONTAINER,
)

with pytest.raises(
ValueError,
match="JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode.",
):
model_builder.build()
82 changes: 82 additions & 0 deletions tests/unit/sagemaker/serve/builder/test_js_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,19 @@
"123456789712.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.24.0-neuronx-sdk2.14.1"
)

mock_model_data = {
"S3DataSource": {
"S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/huggingface-llm/huggingface-llm-zephyr-7b-gemma"
"/artifacts/inference-prepack/v1.0.0/",
"S3DataType": "S3Prefix",
"CompressionType": "None",
}
}
mock_model_data_str = (
"s3://jumpstart-private-cache-prod-us-west-2/huggingface-llm/huggingface-llm-zephyr-7b-gemma"
"/artifacts/inference-prepack/v1.0.0/"
)


class TestJumpStartBuilder(unittest.TestCase):
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
Expand Down Expand Up @@ -527,3 +540,72 @@ def test_tune_for_djl_js_endpoint_mode_ex(

tuned_model = model.tune()
assert tuned_model == model

@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
@patch(
"sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
return_value=True,
)
@patch(
"sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
return_value=MagicMock(),
)
@patch(
"sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources",
return_value=({"model_type": "t5", "n_head": 71}, True),
)
@patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024)
@patch(
"sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge"
)
def test_js_gated_model_in_endpoint_mode(
self,
mock_get_nb_instance,
mock_get_ram_usage_mb,
mock_prepare_for_tgi,
mock_pre_trained_model,
mock_is_jumpstart_model,
mock_telemetry,
):
builder = ModelBuilder(
model="facebook/galactica-mock-model-id",
schema_builder=mock_schema_builder,
mode=Mode.SAGEMAKER_ENDPOINT,
)

mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri
mock_pre_trained_model.return_value.model_data = mock_model_data

model = builder.build()

assert model is not None

@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
@patch(
"sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
return_value=True,
)
@patch(
"sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
return_value=MagicMock(),
)
def test_js_gated_model_in_local_mode(
self,
mock_pre_trained_model,
mock_is_jumpstart_model,
mock_telemetry,
):
builder = ModelBuilder(
model="huggingface-llm-zephyr-7b-gemma",
schema_builder=mock_schema_builder,
mode=Mode.LOCAL_CONTAINER,
)

mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri
mock_pre_trained_model.return_value.model_data = mock_model_data_str

self.assertRaisesRegex(
ValueError,
"JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode.",
lambda: builder.build(),
)