
Feat: Add TEI support for ModelBuilder #4694


Merged: 19 commits, May 21, 2024
src/sagemaker/serve/builder/model_builder.py (2 changes: 1 addition & 1 deletion)
@@ -169,7 +169,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
in order for model builder to build the artifacts correctly (according
to the model server). Possible values for this argument are
``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``,
``TRITON``, and``TGI``.
``TRITON``, ``TGI``, and ``TEI``.
model_metadata (Optional[Dict[str, Any]]): Dictionary used to override model metadata.
Currently, ``HF_TASK`` is overridable for HuggingFace models. ``HF_TASK`` should be set for
new models without task metadata in the Hub; adding unsupported task types will throw
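With this change, ``TEI`` joins the accepted ``model_server`` values. A minimal usage sketch, assuming the ``sagemaker.serve`` builder APIs touched in this PR; the model ID and sample payloads are illustrative, not taken from the diff:

```python
from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.utils.types import ModelServer

# Hypothetical embedding model and sample payloads; substitute your own.
model_builder = ModelBuilder(
    model="BAAI/bge-base-en-v1.5",
    schema_builder=SchemaBuilder(
        sample_input={"inputs": "What is SageMaker?"},
        sample_output=[[0.01, -0.02, 0.03]],
    ),
    model_server=ModelServer.TEI,  # route the build through the new TEI path
)
model = model_builder.build()
```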
src/sagemaker/serve/builder/tei_builder.py (18 changes: 10 additions & 8 deletions)
@@ -25,7 +25,7 @@
_get_nb_instance,
)
from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure
from sagemaker.serve.utils.predictors import TgiLocalModePredictor
from sagemaker.serve.utils.predictors import TeiLocalModePredictor
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
@@ -74,16 +74,16 @@ def _prepare_for_mode(self):
def _get_client_translators(self):
"""Placeholder docstring"""

def _set_to_tgi(self):
def _set_to_tei(self):
"""Placeholder docstring"""
if self.model_server != ModelServer.TGI:
if self.model_server != ModelServer.TEI:
messaging = (
"HuggingFace Model ID support on model server: "
f"{self.model_server} is not currently supported. "
f"Defaulting to {ModelServer.TGI}"
f"Defaulting to {ModelServer.TEI}"
)
logger.warning(messaging)
self.model_server = ModelServer.TGI
self.model_server = ModelServer.TEI

def _create_tei_model(self, **kwargs) -> Type[Model]:
"""Placeholder docstring"""
@@ -142,7 +142,7 @@ def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
if self.mode == Mode.LOCAL_CONTAINER:
timeout = kwargs.get("model_data_download_timeout")

predictor = TgiLocalModePredictor(
predictor = TeiLocalModePredictor(
self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
)

@@ -180,7 +180,9 @@ def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
if "endpoint_logging" not in kwargs:
kwargs["endpoint_logging"] = True

if not self.nb_instance_type and "instance_type" not in kwargs:
if self.nb_instance_type and "instance_type" not in kwargs:
kwargs.update({"instance_type": self.nb_instance_type})
elif not self.nb_instance_type and "instance_type" not in kwargs:
raise ValueError(
"Instance type must be provided when deploying " "to SageMaker Endpoint mode."
)
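The reworked branch above now falls back to the notebook-detected instance type (``self.nb_instance_type``, via ``_get_nb_instance``) before raising; previously a detected type was never applied. A sketch of the resulting deploy behavior, with an illustrative instance type:

```python
# Inside a SageMaker notebook, the detected instance type is injected automatically:
predictor = model.deploy()
# Elsewhere, instance_type must be passed explicitly or a ValueError is raised:
predictor = model.deploy(instance_type="ml.g5.2xlarge")  # illustrative type
```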
@@ -216,7 +218,7 @@ def _build_for_tei(self):
"""Placeholder docstring"""
self.secret_key = None

self._set_to_tgi()
self._set_to_tei()

self.pysdk_model = self._build_for_hf_tei()
return self.pysdk_model
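``_build_for_tei`` now routes through ``_set_to_tei``, and the deploy wrapper above returns a ``TeiLocalModePredictor`` when running locally. A hedged sketch of exercising that path, continuing the builder sketch from the first file (Docker must be running; the payload is illustrative):

```python
from sagemaker.serve.mode.function_pointers import Mode

# Assumes `model_builder` from the earlier sketch.
model = model_builder.build(mode=Mode.LOCAL_CONTAINER)
predictor = model.deploy()  # a TeiLocalModePredictor in this mode
print(predictor.predict({"inputs": "What is SageMaker?"}))
```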
src/sagemaker/serve/mode/local_container_mode.py (15 changes: 15 additions & 0 deletions)
@@ -21,6 +21,7 @@
from sagemaker.serve.model_server.djl_serving.server import LocalDJLServing
from sagemaker.serve.model_server.triton.server import LocalTritonServer
from sagemaker.serve.model_server.tgi.server import LocalTgiServing
from sagemaker.serve.model_server.tei.server import LocalTeiServing
from sagemaker.serve.model_server.multi_model_server.server import LocalMultiModelServer
from sagemaker.session import Session

@@ -69,6 +70,7 @@ def __init__(
self.container = None
self.secret_key = None
self._ping_container = None
self._invoke_serving = None

def load(self, model_path: str = None):
"""Placeholder docstring"""
@@ -156,6 +158,19 @@ def create_server(
env_vars=env_vars if env_vars else self.env_vars,
)
self._ping_container = self._tensorflow_serving_deep_ping
elif self.model_server == ModelServer.TEI:
tei_serving = LocalTeiServing()
tei_serving._start_tei_serving(
client=self.client,
image=image,
model_path=model_path if model_path else self.model_path,
secret_key=secret_key,
env_vars=env_vars if env_vars else self.env_vars,
)
tei_serving.schema_builder = self.schema_builder
self.container = tei_serving.container
self._ping_container = tei_serving._tei_deep_ping
self._invoke_serving = tei_serving._invoke_tei_serving

# allow some time for container to be ready
time.sleep(10)
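The new ``ModelServer.TEI`` branch delegates container startup to ``LocalTeiServing`` (added later in this PR) and wires its ping and invoke hooks onto the mode object. Condensed as a standalone sketch; the Docker client, image URI, paths, and env vars are illustrative assumptions:

```python
import docker

from sagemaker.serve.model_server.tei.server import LocalTeiServing

tei_serving = LocalTeiServing()
tei_serving._start_tei_serving(
    client=docker.from_env(),  # assumes a local Docker daemon
    image="ghcr.io/huggingface/text-embeddings-inference:latest",  # illustrative image
    model_path="/tmp/model",   # the container mounts <model_path>/code
    secret_key="my-secret",    # illustrative
    env_vars={"HF_MODEL_ID": "BAAI/bge-base-en-v1.5"},  # hypothetical override
)
# The mode object then exposes:
ping = tei_serving._tei_deep_ping
invoke = tei_serving._invoke_tei_serving
```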
src/sagemaker/serve/mode/sagemaker_endpoint_mode.py (27 changes: 21 additions & 6 deletions)
@@ -6,6 +6,7 @@
import logging
from typing import Type

from sagemaker.serve.model_server.tei.server import SageMakerTeiServing
from sagemaker.serve.model_server.tensorflow_serving.server import SageMakerTensorflowServing
from sagemaker.session import Session
from sagemaker.serve.utils.types import ModelServer
@@ -37,6 +38,8 @@ def __init__(self, inference_spec: Type[InferenceSpec], model_server: ModelServe
self.inference_spec = inference_spec
self.model_server = model_server

self._tei_serving = SageMakerTeiServing()

def load(self, model_path: str):
"""Placeholder docstring"""
path = Path(model_path)
@@ -66,8 +69,9 @@ def prepare(
+ "session to be created or supply `sagemaker_session` into @serve.invoke."
) from e

upload_artifacts = None
if self.model_server == ModelServer.TORCHSERVE:
return self._upload_torchserve_artifacts(
upload_artifacts = self._upload_torchserve_artifacts(
model_path=model_path,
sagemaker_session=sagemaker_session,
secret_key=secret_key,
@@ -76,7 +80,7 @@
)

if self.model_server == ModelServer.TRITON:
return self._upload_triton_artifacts(
upload_artifacts = self._upload_triton_artifacts(
model_path=model_path,
sagemaker_session=sagemaker_session,
secret_key=secret_key,
@@ -85,15 +89,15 @@
)

if self.model_server == ModelServer.DJL_SERVING:
return self._upload_djl_artifacts(
upload_artifacts = self._upload_djl_artifacts(
model_path=model_path,
sagemaker_session=sagemaker_session,
s3_model_data_url=s3_model_data_url,
image=image,
)

if self.model_server == ModelServer.TGI:
return self._upload_tgi_artifacts(
upload_artifacts = self._upload_tgi_artifacts(
model_path=model_path,
sagemaker_session=sagemaker_session,
s3_model_data_url=s3_model_data_url,
@@ -102,20 +106,31 @@
)

if self.model_server == ModelServer.MMS:
return self._upload_server_artifacts(
upload_artifacts = self._upload_server_artifacts(
model_path=model_path,
sagemaker_session=sagemaker_session,
s3_model_data_url=s3_model_data_url,
image=image,
)

if self.model_server == ModelServer.TENSORFLOW_SERVING:
return self._upload_tensorflow_serving_artifacts(
upload_artifacts = self._upload_tensorflow_serving_artifacts(
model_path=model_path,
sagemaker_session=sagemaker_session,
secret_key=secret_key,
s3_model_data_url=s3_model_data_url,
image=image,
)

if self.model_server == ModelServer.TEI:
upload_artifacts = self._tei_serving._upload_tei_artifacts(
model_path=model_path,
sagemaker_session=sagemaker_session,
s3_model_data_url=s3_model_data_url,
image=image,
)

if upload_artifacts:
return upload_artifacts

raise ValueError("%s model server is not supported" % self.model_server)
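The refactor replaces the per-server early returns with a single ``upload_artifacts`` variable, so any unmatched server still falls through to the final ``ValueError``. Every ``_upload_*_artifacts`` helper returns a ``(model_data, env_vars)`` pair; for the TEI branch the shape is as below (grounded in ``_upload_tei_artifacts`` later in this PR; the S3 URI is illustrative):

```python
# What the TEI branch assigns to `upload_artifacts`:
model_data = {
    "S3DataSource": {
        "CompressionType": "None",
        "S3DataType": "S3Prefix",
        "S3Uri": "s3://my-bucket/code-prefix/code/",  # illustrative URI
    }
}
env_vars = {
    "TRANSFORMERS_CACHE": "/opt/ml/model/",
    "HUGGINGFACE_HUB_CACHE": "/opt/ml/model/",
}
upload_artifacts = (model_data, env_vars)
```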
Empty file.
src/sagemaker/serve/model_server/tei/server.py (160 changes: 160 additions & 0 deletions)
@@ -0,0 +1,160 @@
"""Module for Local TEI Serving"""

from __future__ import absolute_import

import requests
import logging
from pathlib import Path
from docker.types import DeviceRequest
from sagemaker import Session, fw_utils
from sagemaker.serve.utils.exceptions import LocalModelInvocationException
from sagemaker.base_predictor import PredictorBase
from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join
from sagemaker.s3 import S3Uploader
from sagemaker.local.utils import get_docker_host


MODE_DIR_BINDING = "/opt/ml/model/"
_SHM_SIZE = "2G"
_DEFAULT_ENV_VARS = {
"TRANSFORMERS_CACHE": "/opt/ml/model/",
"HUGGINGFACE_HUB_CACHE": "/opt/ml/model/",
}

logger = logging.getLogger(__name__)


class LocalTeiServing:
"""LocalTeiServing class"""

def _start_tei_serving(
self, client: object, image: str, model_path: str, secret_key: str, env_vars: dict
):
"""Starts a local tei serving container.

Args:
client: Docker client
image: Image to use
model_path: Path to the model
secret_key: Secret key to use for authentication
env_vars: Environment variables to set
"""
if env_vars and secret_key:
env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = secret_key

self.container = client.containers.run(
image,
shm_size=_SHM_SIZE,
device_requests=[DeviceRequest(count=-1, capabilities=[["gpu"]])],
network_mode="host",
detach=True,
auto_remove=True,
volumes={
Path(model_path).joinpath("code"): {
"bind": MODE_DIR_BINDING,
"mode": "rw",
},
},
environment=_update_env_vars(env_vars),
)

def _invoke_tei_serving(self, request: object, content_type: str, accept: str):
"""Invokes a local tei serving container.

Args:
request: Request to send
content_type: Content type to use
accept: Accept header to use
"""
try:
response = requests.post(
f"http://{get_docker_host()}:8080/invocations",
data=request,
headers={"Content-Type": content_type, "Accept": accept},
timeout=600,
)
response.raise_for_status()
return response.content
except Exception as e:
raise Exception("Unable to send request to the local container server") from e

def _tei_deep_ping(self, predictor: PredictorBase):
"""Checks if the local tei serving container is up and running.

If the container is not up and running, it will raise an exception.
"""
response = None
try:
response = predictor.predict(self.schema_builder.sample_input)
return (True, response)
# pylint: disable=broad-except
except Exception as e:
if "422 Client Error: Unprocessable Entity for url" in str(e):
raise LocalModelInvocationException(str(e))
return (False, response)


class SageMakerTeiServing:
"""SageMakerTeiServing class"""

def _upload_tei_artifacts(
self,
model_path: str,
sagemaker_session: Session,
s3_model_data_url: str = None,
image: str = None,
env_vars: dict = None,
):
"""Uploads the model artifacts to S3.

Args:
model_path: Path to the model
sagemaker_session: SageMaker session
s3_model_data_url: S3 model data URL
image: Image to use
env_vars: Environment variables to set
"""
if s3_model_data_url:
bucket, key_prefix = parse_s3_url(url=s3_model_data_url)
else:
bucket, key_prefix = None, None

code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image)

bucket, code_key_prefix = determine_bucket_and_prefix(
bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session
)

code_dir = Path(model_path).joinpath("code")

s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code")

logger.debug("Uploading TEI Model Resources uncompressed to: %s", s3_location)

model_data_url = S3Uploader.upload(
str(code_dir),
s3_location,
None,
sagemaker_session,
)

model_data = {
"S3DataSource": {
"CompressionType": "None",
"S3DataType": "S3Prefix",
"S3Uri": model_data_url + "/",
}
}

return (model_data, _update_env_vars(env_vars))


def _update_env_vars(env_vars: dict) -> dict:
"""Placeholder docstring"""
updated_env_vars = {}
updated_env_vars.update(_DEFAULT_ENV_VARS)
if env_vars:
updated_env_vars.update(env_vars)
return updated_env_vars
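``_update_env_vars`` overlays caller-supplied variables on the TEI cache defaults, so user values win on key collisions. A doctest-style sketch (``HF_MODEL_ID`` is an illustrative key):

```python
>>> _update_env_vars({"HUGGINGFACE_HUB_CACHE": "/tmp/hub", "HF_MODEL_ID": "my-model"})
{'TRANSFORMERS_CACHE': '/opt/ml/model/',
 'HUGGINGFACE_HUB_CACHE': '/tmp/hub',
 'HF_MODEL_ID': 'my-model'}
```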