Skip to content

Fix: JS Model with non-TGI/non-DJL deployment failure #4688

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 16, 2024
Merged
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 42 additions & 48 deletions src/sagemaker/serve/builder/jumpstart_builder.py
Original file line number Diff line number Diff line change
@@ -300,6 +300,11 @@ def _tune_for_js(self, sharded_supported: bool, max_tuning_duration: int = 1800)
returns:
Tuned Model.
"""
if self.mode == Mode.SAGEMAKER_ENDPOINT:
logger.warning(
"Tuning is only a %s capability. Returning original model.", Mode.LOCAL_CONTAINER
)
return self.pysdk_model

num_shard_env_var_name = "SM_NUM_GPUS"
if "OPTION_TENSOR_PARALLEL_DEGREE" in self.pysdk_model.env.keys():
@@ -468,58 +473,47 @@ def _build_for_jumpstart(self):
self.secret_key = None
self.jumpstart = True

self.pysdk_model = self._create_pre_trained_js_model()
self.pysdk_model.tune = lambda *args, **kwargs: self._default_tune()

logger.info(
"JumpStart ID %s is packaged with Image URI: %s", self.model, self.pysdk_model.image_uri
)

if self.mode != Mode.SAGEMAKER_ENDPOINT:
if self._is_gated_model(self.pysdk_model):
raise ValueError(
"JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode."
)

if "djl-inference" in self.pysdk_model.image_uri:
logger.info("Building for DJL JumpStart Model ID...")
self.model_server = ModelServer.DJL_SERVING
self.image_uri = self.pysdk_model.image_uri

self._build_for_djl_jumpstart()

self.pysdk_model.tune = self.tune_for_djl_jumpstart
elif "tgi-inference" in self.pysdk_model.image_uri:
logger.info("Building for TGI JumpStart Model ID...")
self.model_server = ModelServer.TGI
self.image_uri = self.pysdk_model.image_uri

self._build_for_tgi_jumpstart()
pysdk_model = self._create_pre_trained_js_model()
image_uri = pysdk_model.image_uri

self.pysdk_model.tune = self.tune_for_tgi_jumpstart
elif "huggingface-pytorch-inference:" in self.pysdk_model.image_uri:
logger.info("Building for MMS JumpStart Model ID...")
self.model_server = ModelServer.MMS
self.image_uri = self.pysdk_model.image_uri
logger.info("JumpStart ID %s is packaged with Image URI: %s", self.model, image_uri)

self._build_for_mms_jumpstart()
else:
raise ValueError(
"JumpStart Model ID was not packaged "
"with djl-inference, tgi-inference, or mms-inference container."
)

return self.pysdk_model
if self._is_gated_model(pysdk_model) and self.mode != Mode.SAGEMAKER_ENDPOINT:
raise ValueError(
"JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode."
)

def _default_tune(self):
"""Logs a warning message if tune is invoked on endpoint mode.
if "djl-inference" in image_uri:
logger.info("Building for DJL JumpStart Model ID...")
self.model_server = ModelServer.DJL_SERVING
self.pysdk_model = pysdk_model
self.image_uri = self.pysdk_model.image_uri

self._build_for_djl_jumpstart()

self.pysdk_model.tune = self.tune_for_djl_jumpstart
elif "tgi-inference" in image_uri:
logger.info("Building for TGI JumpStart Model ID...")
self.model_server = ModelServer.TGI
self.pysdk_model = pysdk_model
self.image_uri = self.pysdk_model.image_uri

self._build_for_tgi_jumpstart()

self.pysdk_model.tune = self.tune_for_tgi_jumpstart
elif "huggingface-pytorch-inference:" in image_uri:
logger.info("Building for MMS JumpStart Model ID...")
self.model_server = ModelServer.MMS
self.pysdk_model = pysdk_model
self.image_uri = self.pysdk_model.image_uri

self._build_for_mms_jumpstart()
elif self.mode != Mode.SAGEMAKER_ENDPOINT:
raise ValueError(
"JumpStart Model ID was not packaged "
"with djl-inference, tgi-inference, or mms-inference container."
)

Returns:
Jumpstart Model: ``This`` model
"""
logger.warning(
"Tuning is only a %s capability. Returning original model.", Mode.LOCAL_CONTAINER
)
return self.pysdk_model

def _is_gated_model(self, model) -> bool:
Expand Down