
Commit 0815ab5

makungaj1 (Jonathan Makunga) authored and root committed
fix: JS Model with non-TGI/non-DJL deployment failure (aws#4688)
* Debug
* Debug
* Debug
* Debug
* Debug
* Debug
* fix docstyle
* Refactoring
* Add Integ tests

---------

Co-authored-by: Jonathan Makunga <[email protected]>
1 parent 30db18d commit 0815ab5
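
For context, the failure being fixed: ModelBuilder could not take a JumpStart model ID whose container is neither TGI nor DJL (the MMS / huggingface-pytorch-inference path) through to deployment. A minimal sketch of the now-supported flow, assembled from identifiers in the diff below; the import paths follow the SDK's layout and the role ARN is a placeholder:

    from sagemaker.serve.builder.model_builder import ModelBuilder
    from sagemaker.serve.builder.schema_builder import SchemaBuilder

    # "huggingface-sentencesimilarity-bge-m3" (used in the new integ test) is
    # packaged with the huggingface-pytorch-inference container, i.e. neither
    # TGI nor DJL, so it exercises the code path this commit repairs.
    model = ModelBuilder(
        model="huggingface-sentencesimilarity-bge-m3",
        schema_builder=SchemaBuilder(
            ["How cute your dog is!"],  # sample input
            {"embedding": []},          # sample output shape
        ),
        role_arn="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder
    ).build()  # with this fix, build() dispatches to the MMS path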

File tree

2 files changed (+92, -48)


src/sagemaker/serve/builder/jumpstart_builder.py (+42, -48)
@@ -300,6 +300,11 @@ def _tune_for_js(self, sharded_supported: bool, max_tuning_duration: int = 1800)
         returns:
             Tuned Model.
         """
+        if self.mode == Mode.SAGEMAKER_ENDPOINT:
+            logger.warning(
+                "Tuning is only a %s capability. Returning original model.", Mode.LOCAL_CONTAINER
+            )
+            return self.pysdk_model
 
         num_shard_env_var_name = "SM_NUM_GPUS"
         if "OPTION_TENSOR_PARALLEL_DEGREE" in self.pysdk_model.env.keys():
@@ -468,58 +473,47 @@ def _build_for_jumpstart(self):
         self.secret_key = None
         self.jumpstart = True
 
-        self.pysdk_model = self._create_pre_trained_js_model()
-        self.pysdk_model.tune = lambda *args, **kwargs: self._default_tune()
-
-        logger.info(
-            "JumpStart ID %s is packaged with Image URI: %s", self.model, self.pysdk_model.image_uri
-        )
-
-        if self.mode != Mode.SAGEMAKER_ENDPOINT:
-            if self._is_gated_model(self.pysdk_model):
-                raise ValueError(
-                    "JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode."
-                )
-
-        if "djl-inference" in self.pysdk_model.image_uri:
-            logger.info("Building for DJL JumpStart Model ID...")
-            self.model_server = ModelServer.DJL_SERVING
-            self.image_uri = self.pysdk_model.image_uri
-
-            self._build_for_djl_jumpstart()
-
-            self.pysdk_model.tune = self.tune_for_djl_jumpstart
-        elif "tgi-inference" in self.pysdk_model.image_uri:
-            logger.info("Building for TGI JumpStart Model ID...")
-            self.model_server = ModelServer.TGI
-            self.image_uri = self.pysdk_model.image_uri
-
-            self._build_for_tgi_jumpstart()
+        pysdk_model = self._create_pre_trained_js_model()
+        image_uri = pysdk_model.image_uri
 
-            self.pysdk_model.tune = self.tune_for_tgi_jumpstart
-        elif "huggingface-pytorch-inference:" in self.pysdk_model.image_uri:
-            logger.info("Building for MMS JumpStart Model ID...")
-            self.model_server = ModelServer.MMS
-            self.image_uri = self.pysdk_model.image_uri
+        logger.info("JumpStart ID %s is packaged with Image URI: %s", self.model, image_uri)
 
-            self._build_for_mms_jumpstart()
-        else:
-            raise ValueError(
-                "JumpStart Model ID was not packaged "
-                "with djl-inference, tgi-inference, or mms-inference container."
-            )
-
-        return self.pysdk_model
+        if self._is_gated_model(pysdk_model) and self.mode != Mode.SAGEMAKER_ENDPOINT:
+            raise ValueError(
+                "JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode."
+            )
 
-    def _default_tune(self):
-        """Logs a warning message if tune is invoked on endpoint mode.
+        if "djl-inference" in image_uri:
+            logger.info("Building for DJL JumpStart Model ID...")
+            self.model_server = ModelServer.DJL_SERVING
+            self.pysdk_model = pysdk_model
+            self.image_uri = self.pysdk_model.image_uri
+
+            self._build_for_djl_jumpstart()
+
+            self.pysdk_model.tune = self.tune_for_djl_jumpstart
+        elif "tgi-inference" in image_uri:
+            logger.info("Building for TGI JumpStart Model ID...")
+            self.model_server = ModelServer.TGI
+            self.pysdk_model = pysdk_model
+            self.image_uri = self.pysdk_model.image_uri
+
+            self._build_for_tgi_jumpstart()
+
+            self.pysdk_model.tune = self.tune_for_tgi_jumpstart
+        elif "huggingface-pytorch-inference:" in image_uri:
+            logger.info("Building for MMS JumpStart Model ID...")
+            self.model_server = ModelServer.MMS
+            self.pysdk_model = pysdk_model
+            self.image_uri = self.pysdk_model.image_uri
+
+            self._build_for_mms_jumpstart()
+        elif self.mode != Mode.SAGEMAKER_ENDPOINT:
+            raise ValueError(
+                "JumpStart Model ID was not packaged "
+                "with djl-inference, tgi-inference, or mms-inference container."
+            )
 
-        Returns:
-            Jumpstart Model: ``This`` model
-        """
-        logger.warning(
-            "Tuning is only a %s capability. Returning original model.", Mode.LOCAL_CONTAINER
-        )
         return self.pysdk_model
 
     def _is_gated_model(self, model) -> bool:
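
Distilled, the rebuilt dispatch selects a model server by substring-matching the JumpStart image URI, and the unknown-container case is now fatal only outside SAGEMAKER_ENDPOINT mode. An illustrative restatement of that rule (a hypothetical helper, not the SDK's actual code):

    from typing import Optional

    def select_model_server(image_uri: str, is_endpoint_mode: bool) -> Optional[str]:
        """Simplified mirror of _build_for_jumpstart's branching above."""
        if "djl-inference" in image_uri:
            return "DJL_SERVING"
        if "tgi-inference" in image_uri:
            return "TGI"
        if "huggingface-pytorch-inference:" in image_uri:
            return "MMS"
        if not is_endpoint_mode:
            raise ValueError(
                "JumpStart Model ID was not packaged "
                "with djl-inference, tgi-inference, or mms-inference container."
            )
        return None  # SAGEMAKER_ENDPOINT mode: keep the pre-trained model as-is

The old code's unconditional "else: raise" is what broke SAGEMAKER_ENDPOINT deployments of models shipped in other containers; the new "elif" lets those fall through and return the pre-trained model.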

tests/integ/sagemaker/serve/test_serve_js_happy.py (+50)
@@ -34,6 +34,14 @@
 JS_MODEL_ID = "huggingface-textgeneration1-gpt-neo-125m-fp16"
 ROLE_NAME = "SageMakerRole"
 
+SAMPLE_MMS_PROMPT = [
+    "How cute your dog is!",
+    "Your dog is so cute.",
+    "The mitochondria is the powerhouse of the cell.",
+]
+SAMPLE_MMS_RESPONSE = {"embedding": []}
+JS_MMS_MODEL_ID = "huggingface-sentencesimilarity-bge-m3"
+
 
 @pytest.fixture
 def happy_model_builder(sagemaker_session):
@@ -46,6 +54,17 @@ def happy_model_builder(sagemaker_session):
     )
 
 
+@pytest.fixture
+def happy_mms_model_builder(sagemaker_session):
+    iam_client = sagemaker_session.boto_session.client("iam")
+    return ModelBuilder(
+        model=JS_MMS_MODEL_ID,
+        schema_builder=SchemaBuilder(SAMPLE_MMS_PROMPT, SAMPLE_MMS_RESPONSE),
+        role_arn=iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"],
+        sagemaker_session=sagemaker_session,
+    )
+
+
 @pytest.mark.skipif(
     PYTHON_VERSION_IS_NOT_310,
     reason="The goal of these test are to test the serving components of our feature",
@@ -75,3 +94,34 @@ def test_happy_tgi_sagemaker_endpoint(happy_model_builder, gpu_instance_type):
     )
     if caught_ex:
         raise caught_ex
+
+
+@pytest.mark.skipif(
+    PYTHON_VERSION_IS_NOT_310,
+    reason="The goal of these test are to test the serving components of our feature",
+)
+@pytest.mark.slow_test
+def test_happy_mms_sagemaker_endpoint(happy_mms_model_builder, gpu_instance_type):
+    logger.info("Running in SAGEMAKER_ENDPOINT mode...")
+    caught_ex = None
+    model = happy_mms_model_builder.build()
+
+    with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
+        try:
+            logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
+            predictor = model.deploy(instance_type=gpu_instance_type, endpoint_logging=False)
+            logger.info("Endpoint successfully deployed.")
+
+            updated_sample_input = happy_mms_model_builder.schema_builder.sample_input
+
+            predictor.predict(updated_sample_input)
+        except Exception as e:
+            caught_ex = e
+        finally:
+            cleanup_model_resources(
+                sagemaker_session=happy_mms_model_builder.sagemaker_session,
+                model_name=model.name,
+                endpoint_name=model.endpoint_name,
+            )
+    if caught_ex:
+        raise caught_ex
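
The new test mirrors the existing TGI happy-path test but drives the MMS route end to end: build, deploy to a GPU endpoint, predict against the embedding-shaped schema, and clean up. A hypothetical way to run just this test from the repo root (assumes AWS credentials, the SageMakerRole IAM role, and the repo's integ-test fixtures are configured):

    import pytest

    # Node ID matches the test added above; "slow_test" is the marker it carries.
    pytest.main([
        "tests/integ/sagemaker/serve/test_serve_js_happy.py::test_happy_mms_sagemaker_endpoint",
        "-m", "slow_test",
    ])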
