Skip to content

Commit 7e0d775

Browse files
makungaj1Jonathan Makunga
authored and
root
committed
feat: JumpStart Gated Model Support in ModelBuilder Local Modes (aws#4567)
* Restrict JS Gated models only in SM Endpoint mode * Add tests * Address PR review comments * Resolve PR comments * Resolve PR comments * Resolve PR comments * Resolve PR comments * Fix integ tests * Fix tests * Fix tests * Fix integ tests --------- Co-authored-by: Jonathan Makunga <[email protected]>
1 parent 7ba88ce commit 7e0d775

File tree

3 files changed

+137
-3
lines changed

3 files changed

+137
-3
lines changed

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,11 @@ def _build_for_jumpstart(self):
443443

444444
logger.info("JumpStart ID %s is packaged with Image URI: %s", self.model, image_uri)
445445

446+
if self._is_gated_model(pysdk_model) and self.mode != Mode.SAGEMAKER_ENDPOINT:
447+
raise ValueError(
448+
"JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode."
449+
)
450+
446451
if "djl-inference" in image_uri:
447452
logger.info("Building for DJL JumpStart Model ID...")
448453
self.model_server = ModelServer.DJL_SERVING
@@ -469,3 +474,19 @@ def _build_for_jumpstart(self):
469474
)
470475

471476
return self.pysdk_model
477+
478+
def _is_gated_model(self, model) -> bool:
    """Determine whether the given JumpStart model is gated.

    A model is treated as gated when the S3 URI of its artifacts contains
    the substring ``"private"``.

    Args:
        model (Model): JumpStart model. Its ``model_data`` is read either as an
            S3 URI string or as a dict shaped like
            ``{"S3DataSource": {"S3Uri": ...}}``.

    Returns:
        bool: ``True`` if the resolved S3 URI contains ``"private"``; ``False``
            otherwise, including when no string URI can be resolved.
    """
    s3_uri = model.model_data
    if isinstance(s3_uri, dict):
        # A missing or null "S3DataSource" entry must not blow up with
        # AttributeError; it simply means there is no URI to inspect.
        data_source = s3_uri.get("S3DataSource")
        s3_uri = data_source.get("S3Uri") if isinstance(data_source, dict) else None

    # Only a string URI can be checked for the marker; None and any other
    # unresolved value are reported as "not gated".
    if not isinstance(s3_uri, str):
        return False
    return "private" in s3_uri

tests/integ/sagemaker/serve/test_serve_js_happy.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,15 @@
3131
SAMPLE_RESPONSE = [
3232
{"generated_text": "Hello, I'm a language model, and I'm here to help you with your English."}
3333
]
34-
JS_MODEL_ID = "huggingface-textgeneration1-gpt-neo-125m-fp16"
34+
JS_GATED_MODEL_ID = "meta-textgeneration-llama-2-7b-f"
3535
ROLE_NAME = "SageMakerRole"
3636

3737

3838
@pytest.fixture
3939
def happy_model_builder(sagemaker_session):
4040
iam_client = sagemaker_session.boto_session.client("iam")
4141
return ModelBuilder(
42-
model=JS_MODEL_ID,
42+
model=JS_GATED_MODEL_ID,
4343
schema_builder=SchemaBuilder(SAMPLE_PROMPT, SAMPLE_RESPONSE),
4444
role_arn=iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"],
4545
sagemaker_session=sagemaker_session,
@@ -59,7 +59,9 @@ def test_happy_tgi_sagemaker_endpoint(happy_model_builder, gpu_instance_type):
5959
with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
6060
try:
6161
logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
62-
predictor = model.deploy(instance_type=gpu_instance_type, endpoint_logging=False)
62+
predictor = model.deploy(
63+
instance_type=gpu_instance_type, endpoint_logging=False, accept_eula=True
64+
)
6365
logger.info("Endpoint successfully deployed.")
6466

6567
updated_sample_input = happy_model_builder.schema_builder.sample_input

tests/unit/sagemaker/serve/builder/test_js_builder.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,19 @@
6767
"123456789712.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.24.0-neuronx-sdk2.14.1"
6868
)
6969

70+
# S3 prefix of the mocked model artifacts, shared by the string and dict
# forms of ``model_data`` used in the tests.
mock_model_data_str = (
    "s3://jumpstart-private-cache-prod-us-west-2/huggingface-llm/huggingface-llm-zephyr-7b-gemma"
    "/artifacts/inference-prepack/v1.0.0/"
)
# Dict form of ``model_data`` wrapping the same URI under "S3DataSource".
mock_model_data = {
    "S3DataSource": {
        "S3Uri": mock_model_data_str,
        "S3DataType": "S3Prefix",
        "CompressionType": "None",
    }
}
82+
7083

7184
class TestJumpStartBuilder(unittest.TestCase):
7285
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
@@ -527,3 +540,101 @@ def test_tune_for_djl_js_endpoint_mode_ex(
527540

528541
tuned_model = model.tune()
529542
assert tuned_model == model
543+
544+
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
@patch(
    "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
    return_value=True,
)
@patch(
    "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
    return_value=MagicMock(),
)
@patch(
    "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources",
    return_value=({"model_type": "t5", "n_head": 71}, True),
)
@patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024)
@patch(
    "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge"
)
def test_js_gated_model_in_endpoint_mode(
    self,
    mock_get_nb_instance,
    mock_get_ram_usage_mb,
    mock_prepare_for_tgi,
    mock_pre_trained_model,
    mock_is_jumpstart_model,
    mock_telemetry,
):
    """Building with gated ``model_data`` in SAGEMAKER_ENDPOINT mode returns a model."""
    model_builder = ModelBuilder(
        model="facebook/galactica-mock-model-id",
        schema_builder=mock_schema_builder,
        mode=Mode.SAGEMAKER_ENDPOINT,
    )

    # Make the mocked pre-trained model look like a TGI model whose artifacts
    # are described by the dict form of ``model_data``.
    js_model = mock_pre_trained_model.return_value
    js_model.image_uri = mock_tgi_image_uri
    js_model.model_data = mock_model_data

    built_model = model_builder.build()

    assert built_model is not None
582+
583+
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
@patch(
    "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
    return_value=True,
)
@patch(
    "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
    return_value=MagicMock(),
)
def test_js_gated_model_in_local_mode(
    self,
    mock_pre_trained_model,
    mock_is_jumpstart_model,
    mock_telemetry,
):
    """Building with gated ``model_data`` in LOCAL_CONTAINER mode raises ``ValueError``."""
    model_builder = ModelBuilder(
        model="huggingface-llm-zephyr-7b-gemma",
        schema_builder=mock_schema_builder,
        mode=Mode.LOCAL_CONTAINER,
    )

    # Here the mocked model exposes its artifacts through the plain-string
    # form of ``model_data``.
    js_model = mock_pre_trained_model.return_value
    js_model.image_uri = mock_tgi_image_uri
    js_model.model_data = mock_model_data_str

    with self.assertRaisesRegex(
        ValueError,
        "JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode.",
    ):
        model_builder.build()
612+
613+
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
@patch(
    "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
    return_value=True,
)
@patch(
    "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
    return_value=MagicMock(),
)
def test_js_gated_model_ex(
    self,
    mock_pre_trained_model,
    mock_is_jumpstart_model,
    mock_telemetry,
):
    """With ``model_data`` set to ``None`` in LOCAL_CONTAINER mode, ``build()`` raises ``ValueError``."""
    model_builder = ModelBuilder(
        model="huggingface-llm-zephyr-7b-gemma",
        schema_builder=mock_schema_builder,
        mode=Mode.LOCAL_CONTAINER,
    )

    # The mocked model carries an image URI but no artifact location at all.
    js_model = mock_pre_trained_model.return_value
    js_model.image_uri = mock_tgi_image_uri
    js_model.model_data = None

    with self.assertRaises(ValueError):
        model_builder.build()

0 commit comments

Comments
 (0)