
feat: Introduce HF Transformers to ModelBuilder #4368


Merged · 14 commits · Feb 2, 2024
19 changes: 14 additions & 5 deletions src/sagemaker/serve/builder/model_builder.py
@@ -34,6 +34,7 @@
from sagemaker.serve.builder.djl_builder import DJL
from sagemaker.serve.builder.tgi_builder import TGI
from sagemaker.serve.builder.jumpstart_builder import JumpStart
from sagemaker.serve.builder.transformers_builder import Transformers
from sagemaker.predictor import Predictor
from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import Metadata
from sagemaker.serve.spec.inference_spec import InferenceSpec
@@ -53,6 +54,7 @@
from sagemaker.serve.validations.check_image_and_hardware_type import (
validate_image_uri_and_hardware,
)
from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata

logger = logging.getLogger(__name__)

@@ -65,7 +67,7 @@

# pylint: disable=attribute-defined-outside-init
@dataclass
class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
"""Class that builds a deployable model.

Args:
@@ -125,8 +127,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI):
in order for model builder to build the artifacts correctly (according
to the model server). Possible values for this argument are
``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``,
``TRITON``, and ``TGI``.
"""

model_path: Optional[str] = field(
@@ -535,7 +536,7 @@ def wrapper(*args, **kwargs):
return wrapper

# Model Builder is a class to build the model for deployment.
# It supports two modes of deployment
# 1/ SageMaker Endpoint
# 2/ Local launch with container
def build(
@@ -577,12 +578,20 @@ def build(
)

self.serve_settings = self._get_serve_setting()

hf_model_md = get_huggingface_model_metadata(
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
)

if isinstance(self.model, str):
if self._is_jumpstart_model_id():
return self._build_for_jumpstart()
if self._is_djl():
return self._build_for_djl()
if hf_model_md.get("pipeline_tag") == "text-generation": # pylint: disable=R1705
return self._build_for_tgi()
else:
return self._build_for_transformers()

self._build_validations()

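The net effect of this hunk: for a plain Hugging Face model ID, build() now consults the hub before defaulting to TGI. get_huggingface_model_metadata() fetches the hub's /api/models/<model_id> JSON, and only models whose pipeline_tag is "text-generation" keep the TGI path; everything else routes to the new Transformers builder. A minimal usage sketch of that routing follows; the model ID, schema values, and role ARN are illustrative assumptions, not part of this diff:

from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder

# "bert-base-uncased" is tagged "fill-mask" on the hub, so build() skips the
# TGI branch and lands in _build_for_transformers() from the new file below.
builder = ModelBuilder(
    model="bert-base-uncased",
    schema_builder=SchemaBuilder(
        sample_input={"inputs": "Paris is the [MASK] of France."},
        sample_output=[{"score": 0.97, "token_str": "capital"}],
    ),
    role_arn="arn:aws:iam::111122223333:role/ExampleSageMakerRole",
)
model = builder.build()  # returns a HuggingFaceModel wired for the MMS path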
280 changes: 280 additions & 0 deletions src/sagemaker/serve/builder/transformers_builder.py
@@ -0,0 +1,280 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Transformers build logic with model builder"""
from __future__ import absolute_import
import logging
from abc import ABC, abstractmethod
from typing import Type
from packaging.version import Version

from sagemaker.model import Model
from sagemaker import image_uris
from sagemaker.serve.utils.local_hardware import (
_get_nb_instance,
)
from sagemaker.djl_inference.model import _get_model_config_properties_from_hf
from sagemaker.huggingface import HuggingFaceModel
from sagemaker.serve.model_server.multi_model_server.prepare import (
_create_dir_structure,
)
from sagemaker.serve.utils.predictors import TransformersLocalModePredictor
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
from sagemaker.base_predictor import PredictorBase
from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata

logger = logging.getLogger(__name__)
DEFAULT_TIMEOUT = 1800


"""Retrieves images for different libraries - Pytorch, TensorFlow from HuggingFace hub
"""


# pylint: disable=W0108
class Transformers(ABC):
"""Transformers build logic with ModelBuilder()"""

def __init__(self):
self.model = None
self.serve_settings = None
self.sagemaker_session = None
self.model_path = None
self.dependencies = None
self.modes = None
self.mode = None
self.model_server = None
self.image_uri = None
self._original_deploy = None
self.hf_model_config = None
self._default_data_type = None
self.pysdk_model = None
self.env_vars = None
self.nb_instance_type = None
self.ram_usage_model_load = None
self.secret_key = None
self.role_arn = None
self.py_version = None
self.tensorflow_version = None
self.pytorch_version = None
self.instance_type = None
self.schema_builder = None

@abstractmethod
def _prepare_for_mode(self):
"""Abstract method"""

def _create_transformers_model(self) -> Type[Model]:
"""Initializes the model after fetching image

1. Get the metadata for deciding framework
2. Get the supported hugging face versions
3. Create model
4. Fetch image

Returns:
pysdk_model: Corresponding model instance
"""

hf_model_md = get_huggingface_model_metadata(
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
)
hf_config = image_uris.config_for_framework("huggingface").get("inference")
config = hf_config["versions"]
base_hf_version = sorted(config.keys(), key=lambda v: Version(v))[0]

if hf_model_md is None:
raise ValueError("Could not fetch HF metadata")

if "pytorch" in hf_model_md.get("tags"):
self.pytorch_version = self._get_supported_version(
hf_config, base_hf_version, "pytorch"
)
self.py_version = config[base_hf_version]["pytorch" + self.pytorch_version].get(
"py_versions"
)[-1]
pysdk_model = HuggingFaceModel(
env=self.env_vars,
role=self.role_arn,
sagemaker_session=self.sagemaker_session,
py_version=self.py_version,
transformers_version=base_hf_version,
pytorch_version=self.pytorch_version,
)
elif "keras" in hf_model_md.get("tags") or "tensorflow" in hf_model_md.get("tags"):
self.tensorflow_version = self._get_supported_version(
hf_config, base_hf_version, "tensorflow"
)
self.py_version = config[base_hf_version]["tensorflow" + self.tensorflow_version].get(
"py_versions"
)[-1]
pysdk_model = HuggingFaceModel(
env=self.env_vars,
role=self.role_arn,
sagemaker_session=self.sagemaker_session,
py_version=self.py_version,
transformers_version=base_hf_version,
tensorflow_version=self.tensorflow_version,
)

if self.mode == Mode.LOCAL_CONTAINER:
self.image_uri = pysdk_model.serving_image_uri(
self.sagemaker_session.boto_region_name, "local"
)
else:
self.image_uri = pysdk_model.serving_image_uri(
self.sagemaker_session.boto_region_name, self.instance_type
)

logger.info("Detected %s. Proceeding with the the deployment.", self.image_uri)

self._original_deploy = pysdk_model.deploy
pysdk_model.deploy = self._transformers_model_builder_deploy_wrapper
return pysdk_model
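# Editor's aside, not part of this file: a standalone way to see which DLC the
# serving_image_uri() calls above resolve to. The arguments mirror what this
# method computes; the concrete values are illustrative and must exist in the
# SDK's bundled "huggingface" config:
#
#   from sagemaker import image_uris
#
#   uri = image_uris.retrieve(
#       framework="huggingface",
#       region="us-west-2",
#       version="4.26.0",                        # transformers_version
#       base_framework_version="pytorch1.13.1",  # framework + pytorch_version
#       py_version="py39",
#       instance_type="ml.m5.xlarge",
#       image_scope="inference",
#   )
#   # e.g. 763104351884.dkr.ecr.us-west-2.amazonaws.com/
#   #      huggingface-pytorch-inference:1.13.1-transformers4.26.0-cpu-py39-ubuntu20.04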

@_capture_telemetry("transformers.deploy")
def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
"""Returns predictor depending on local or sagemaker endpoint mode

Returns:
TransformersLocalModePredictor: During local mode deployment
"""
timeout = kwargs.get("model_data_download_timeout")
if timeout:
self.env_vars.update({"MODEL_LOADING_TIMEOUT": str(timeout)})

if "mode" in kwargs and kwargs.get("mode") != self.mode:
overwrite_mode = kwargs.get("mode")
# mode overwritten by customer during model.deploy()
logger.warning(
"Deploying in %s Mode, overriding existing configurations set for %s mode",
overwrite_mode,
self.mode,
)

if overwrite_mode == Mode.SAGEMAKER_ENDPOINT:
self.mode = self.pysdk_model.mode = Mode.SAGEMAKER_ENDPOINT
elif overwrite_mode == Mode.LOCAL_CONTAINER:
self._prepare_for_mode()
self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
else:
raise ValueError("Mode %s is not supported!" % overwrite_mode)

self._set_instance()

serializer = self.schema_builder.input_serializer
deserializer = self.schema_builder._output_deserializer
if self.mode == Mode.LOCAL_CONTAINER:
timeout = kwargs.get("model_data_download_timeout")

predictor = TransformersLocalModePredictor(
self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
)

self.modes[str(Mode.LOCAL_CONTAINER)].create_server(
self.image_uri,
timeout if timeout else DEFAULT_TIMEOUT,
None,
predictor,
self.pysdk_model.env,
jumpstart=False,
)
return predictor

if "mode" in kwargs:
del kwargs["mode"]
if "role" in kwargs:
self.pysdk_model.role = kwargs.get("role")
del kwargs["role"]

# set model_data to uncompressed s3 dict
self.pysdk_model.model_data, env_vars = self._prepare_for_mode()
self.env_vars.update(env_vars)
self.pysdk_model.env.update(self.env_vars)

if "endpoint_logging" not in kwargs:
kwargs["endpoint_logging"] = True

if "initial_instance_count" not in kwargs:
kwargs.update({"initial_instance_count": 1})

predictor = self._original_deploy(*args, **kwargs)

predictor.serializer = serializer
predictor.deserializer = deserializer
return predictor
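# Editor's aside, illustrative only: after build(), deploy() is routed through
# the wrapper above, so both modes are driven from the same model object:
#
#   from sagemaker.serve.mode.function_pointers import Mode
#
#   model = model_builder.build()
#   local_predictor = model.deploy(mode=Mode.LOCAL_CONTAINER)
#   endpoint_predictor = model.deploy(
#       mode=Mode.SAGEMAKER_ENDPOINT,
#       instance_type="ml.m5.xlarge",  # passed through to Model.deploy()
#       initial_instance_count=1,
#   )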

def _build_transformers_env(self):
"""Build model for hugging face deployment using"""
self.nb_instance_type = _get_nb_instance()

_create_dir_structure(self.model_path)
if not hasattr(self, "pysdk_model"):
self.env_vars.update({"HF_MODEL_ID": self.model})

logger.info(self.env_vars)

# TODO: Move to a helper function
if "HF_API_TOKEN" in self.env_vars:  # env_vars is a dict, so check the key, not an attribute
self.hf_model_config = _get_model_config_properties_from_hf(
self.model, self.env_vars.get("HF_API_TOKEN")
)
else:
self.hf_model_config = _get_model_config_properties_from_hf(
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
)

self.pysdk_model = self._create_transformers_model()

if self.mode == Mode.LOCAL_CONTAINER:
self._prepare_for_mode()

return self.pysdk_model
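# Editor's aside, illustrative values: because HF_MODEL_ID is set above, the
# Hugging Face inference toolkit inside the container downloads the weights
# from the hub at startup, so no artifacts are packaged locally. The container
# environment ends up resembling:
#
#   {"HF_MODEL_ID": "bert-base-uncased", "HUGGING_FACE_HUB_TOKEN": "hf_..."}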

def _set_instance(self, **kwargs):
"""Set the instance : Given the detected notebook type or provided instance type"""
if self.mode == Mode.SAGEMAKER_ENDPOINT:
if self.nb_instance_type and "instance_type" not in kwargs:
kwargs.update({"instance_type": self.nb_instance_type})
elif self.instance_type and "instance_type" not in kwargs:
kwargs.update({"instance_type": self.instance_type})
else:
raise ValueError(
"Instance type must be provided when deploying to SageMaker Endpoint mode."
)
logger.info("Setting instance type to %s", self.instance_type)

def _get_supported_version(self, hf_config, hugging_face_version, base_fw):
"""Uses the hugging face json config to pick supported versions"""
version_config = hf_config.get("versions").get(hugging_face_version)
versions_to_return = list()
for key in list(version_config.keys()):
if key.startswith(base_fw):
base_fw_version = key[len(base_fw) :]
if len(hugging_face_version.split(".")) == 2:
base_fw_version = ".".join(base_fw_version.split(".")[:-1])
versions_to_return.append(base_fw_version)
return sorted(versions_to_return)[0]
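# Editor's aside, a worked example of the rule above against a hypothetical
# slice of the "huggingface" inference version config:
#
#   version_config = {
#       "pytorch1.13.1": {...},
#       "pytorch2.0.0": {...},
#       "tensorflow2.11.0": {...},
#   }
#
# With base_fw="pytorch", the loop keeps ["1.13.1", "2.0.0"] after stripping
# the prefix, and sorted(...)[0] returns "1.13.1": the lowest supported
# framework version wins, by plain string ordering (unlike the Version()-keyed
# sort in _create_transformers_model).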

def _build_for_transformers(self):
"""Method that triggers model build

Returns: PySDK model
"""
self.secret_key = None
self.model_server = ModelServer.MMS

self._build_transformers_env()

return self.pysdk_model
14 changes: 13 additions & 1 deletion src/sagemaker/serve/mode/local_container_mode.py
@@ -19,6 +19,7 @@
from sagemaker.serve.model_server.djl_serving.server import LocalDJLServing
from sagemaker.serve.model_server.triton.server import LocalTritonServer
from sagemaker.serve.model_server.tgi.server import LocalTgiServing
from sagemaker.serve.model_server.multi_model_server.server import LocalMultiModelServer
from sagemaker.session import Session

logger = logging.getLogger(__name__)
@@ -31,7 +32,9 @@
)


class LocalContainerMode(
LocalTorchServe, LocalDJLServing, LocalTritonServer, LocalTgiServing, LocalMultiModelServer
):
"""A class that holds methods to deploy model to a container in local environment"""

def __init__(
@@ -128,6 +131,15 @@ def create_server(
jumpstart=jumpstart,
)
self._ping_container = self._tgi_deep_ping
elif self.model_server == ModelServer.MMS:
self._start_serving(
client=self.client,
image=image,
model_path=model_path if model_path else self.model_path,
secret_key=secret_key,
env_vars=env_vars if env_vars else self.env_vars,
)
self._ping_container = self._multi_model_server_deep_ping

# allow some time for container to be ready
time.sleep(10)
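With the MMS branch above, local mode now boots a multi-model-server container and polls it via _multi_model_server_deep_ping before returning. For intuition only, a deep ping against a local MMS container amounts to something like the sketch below; the function name and URL are illustrative, and the real check lives in sagemaker.serve.model_server.multi_model_server.server:

import requests

def example_mms_deep_ping(base_url: str = "http://localhost:8080") -> bool:
    """Illustrative probe: MMS answers GET /ping once the container is up."""
    try:
        response = requests.get(f"{base_url}/ping", timeout=5)
        return response.status_code == 200
    except requests.exceptions.RequestException:
        return False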