Skip to content

Commit 734ec60

Browse files
committed
feat: Introduce HF Transformers to ModelBuilder
1 parent 086c946 commit 734ec60

File tree

15 files changed

+663
-8
lines changed

15 files changed

+663
-8
lines changed

src/sagemaker/serve/builder/djl_builder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
)
4747
from sagemaker.serve.model_server.djl_serving.prepare import (
4848
prepare_for_djl_serving,
49-
_create_dir_structure,
49+
_create_dir_structure
5050
)
5151
from sagemaker.serve.utils.predictors import DjlLocalModePredictor
5252
from sagemaker.serve.utils.types import ModelServer, _DjlEngine
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""HuggingFace DLC specific model builder"""
14+
from __future__ import absolute_import
15+
import logging
16+
from packaging.version import Version
17+
from typing import Type
18+
from abc import ABC, abstractmethod
19+
20+
from sagemaker.model import Model
21+
from sagemaker import Session, image_uris
22+
from sagemaker.serve.utils.local_hardware import (
23+
_get_nb_instance,
24+
_get_ram_usage_mb,
25+
_get_gpu_info,
26+
_get_gpu_info_fallback,
27+
)
28+
from sagemaker.djl_inference.model import _get_model_config_properties_from_hf
29+
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
30+
from sagemaker.serve.model_server.hf_dlc.prepare import (
31+
_create_dir_structure,
32+
)
33+
from sagemaker.serve.utils.predictors import HfDLCLocalModePredictor
34+
from sagemaker.serve.utils.types import ModelServer
35+
from sagemaker.serve.mode.function_pointers import Mode
36+
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
37+
from sagemaker.base_predictor import PredictorBase
38+
from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata
39+
40+
logger = logging.getLogger(__name__)
41+
DEFAULT_TIMEOUT = 1800
42+
43+
44+
class HuggingFaceDLC(ABC):
    """HuggingFace DLC build logic for ModelBuilder()."""

    def __init__(self):
        self.model = None
        self.serve_settings = None
        self.sagemaker_session = None
        self.model_path = None
        self.dependencies = None
        self.modes = None
        self.mode = None
        self.model_server = None
        self.image_uri = None
        self._original_deploy = None
        self.hf_model_config = None
        self._default_data_type = None
        self.pysdk_model = None
        self.env_vars = None
        self.nb_instance_type = None
        self.ram_usage_model_load = None
        self.secret_key = None
        self.role_arn = None
        self.py_version = None
        self.tensorflow_version = None
        self.pytorch_version = None

    @abstractmethod
    def _prepare_for_mode(self):
        """Abstract method"""

    def _create_hf_dlc_model(self) -> Type[Model]:
        """Initializes the model after fetching image.

        1. Get the metadata for deciding framework
        2. Get the supported hugging face versions
        3. Create model
        4. Fetch image

        Returns:
            HuggingFaceModel: model instance whose ``deploy`` has been
            swapped for ``_hf_dlc_model_builder_deploy_wrapper``.

        Raises:
            ValueError: if HF metadata cannot be fetched, or the model is
                tagged as neither a PyTorch nor a TensorFlow/Keras model.
        """
        hf_model_md = get_huggingface_model_metadata(
            self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
        )
        # Fail fast before touching the image config.
        if hf_model_md is None:
            raise ValueError("Could not fetch HF metadata")

        hf_config = image_uris.config_for_framework("huggingface").get("inference")
        config = hf_config["versions"]
        # NOTE(review): picks the LOWEST available transformers container
        # version — presumably for stability; confirm this is intended.
        base_hf_version = sorted(config.keys(), key=lambda v: Version(v))[0]

        # Guard against missing "tags": `in None` would raise TypeError.
        tags = hf_model_md.get("tags") or []
        if "pytorch" in tags:
            self.pytorch_version = self._get_supported_version(
                hf_config, base_hf_version, "pytorch"
            )
            self.py_version = config[base_hf_version][
                "pytorch" + self.pytorch_version
            ].get("py_versions")[-1]
            pysdk_model = HuggingFaceModel(
                env=self.env_vars,
                role=self.role_arn,
                sagemaker_session=self.sagemaker_session,
                py_version=self.py_version,
                transformers_version=base_hf_version,
                pytorch_version=self.pytorch_version,
            )
        elif "keras" in tags or "tensorflow" in tags:
            self.tensorflow_version = self._get_supported_version(
                hf_config, base_hf_version, "tensorflow"
            )
            self.py_version = config[base_hf_version][
                "tensorflow" + self.tensorflow_version
            ].get("py_versions")[-1]
            pysdk_model = HuggingFaceModel(
                env=self.env_vars,
                role=self.role_arn,
                sagemaker_session=self.sagemaker_session,
                py_version=self.py_version,
                transformers_version=base_hf_version,
                tensorflow_version=self.tensorflow_version,
            )
        else:
            # Original code fell through here with `pysdk_model` unbound and
            # crashed with an UnboundLocalError; raise a clear error instead.
            raise ValueError(
                "Unable to determine framework for model %s; expected a "
                "pytorch, keras, or tensorflow tag in the HF metadata" % self.model
            )

        if self.mode == Mode.LOCAL_CONTAINER:
            self.image_uri = pysdk_model.serving_image_uri(
                self.sagemaker_session.boto_region_name, "local"
            )
        else:
            self.image_uri = pysdk_model.serving_image_uri(
                self.sagemaker_session.boto_region_name, self.instance_type
            )

        logger.info("Detected %s. Proceeding with the deployment.", self.image_uri)

        # Wrap deploy so ModelBuilder can route between local and endpoint modes.
        self._original_deploy = pysdk_model.deploy
        pysdk_model.deploy = self._hf_dlc_model_builder_deploy_wrapper
        return pysdk_model

    @_capture_telemetry("hf_dlc.deploy")
    def _hf_dlc_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
        """Returns predictor depending on local or sagemaker endpoint mode.

        Returns:
            HfDLCLocalModePredictor: during local mode deployment; otherwise
            the predictor returned by the original ``deploy``.

        Raises:
            ValueError: if an unsupported ``mode`` override is supplied.
        """
        timeout = kwargs.get("model_data_download_timeout")
        if timeout:
            self.env_vars.update({"MODEL_LOADING_TIMEOUT": str(timeout)})

        if "mode" in kwargs and kwargs.get("mode") != self.mode:
            overwrite_mode = kwargs.get("mode")
            # mode overwritten by customer during model.deploy()
            logger.warning(
                "Deploying in %s Mode, overriding existing configurations set for %s mode",
                overwrite_mode,
                self.mode,
            )

            if overwrite_mode == Mode.SAGEMAKER_ENDPOINT:
                self.mode = self.pysdk_model.mode = Mode.SAGEMAKER_ENDPOINT
            elif overwrite_mode == Mode.LOCAL_CONTAINER:
                self._prepare_for_mode()
                self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
            else:
                raise ValueError("Mode %s is not supported!" % overwrite_mode)

        # Fix: merge the resolved instance type back into the deploy kwargs.
        # The original called self._set_instance() with no arguments, so the
        # chosen instance type was computed and then silently discarded.
        kwargs.update(self._set_instance(**kwargs))

        serializer = self.schema_builder.input_serializer
        deserializer = self.schema_builder._output_deserializer
        if self.mode == Mode.LOCAL_CONTAINER:
            predictor = HfDLCLocalModePredictor(
                self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
            )

            ram_usage_before = _get_ram_usage_mb()
            self.modes[str(Mode.LOCAL_CONTAINER)].create_server(
                self.image_uri,
                timeout if timeout else DEFAULT_TIMEOUT,
                None,
                predictor,
                self.pysdk_model.env,
                jumpstart=False,
            )

            # Rough estimate of the memory consumed by loading the model.
            ram_usage_after = _get_ram_usage_mb()
            self.ram_usage_model_load = max(ram_usage_after - ram_usage_before, 0)

            return predictor

        if "mode" in kwargs:
            del kwargs["mode"]
        if "role" in kwargs:
            self.pysdk_model.role = kwargs.get("role")
            del kwargs["role"]

        # set model_data to uncompressed s3 dict
        self.pysdk_model.model_data, env_vars = self._prepare_for_mode()
        self.env_vars.update(env_vars)
        self.pysdk_model.env.update(self.env_vars)

        if "endpoint_logging" not in kwargs:
            kwargs["endpoint_logging"] = True

        if "initial_instance_count" not in kwargs:
            kwargs.update({"initial_instance_count": 1})

        predictor = self._original_deploy(*args, **kwargs)

        predictor.serializer = serializer
        predictor.deserializer = deserializer
        return predictor

    def _build_for_hugging_face_dlc(self):
        """Build model for hugging face deployment.

        Returns:
            HuggingFaceModel: deployable PySDK model (the original docstring
            incorrectly claimed a predictor was returned).
        """
        self.nb_instance_type = _get_nb_instance()

        _create_dir_structure(self.model_path)
        # NOTE(review): when ModelBuilder's dataclass init skips our
        # __init__, pysdk_model is genuinely absent and HF_MODEL_ID is set
        # here; if __init__ DID run, hasattr is always True — confirm intent.
        if not hasattr(self, "pysdk_model"):
            self.env_vars.update({"HF_MODEL_ID": self.model})

        logger.info(self.env_vars)

        # TODO: Move to a helper function
        # Fix: env_vars is a dict — hasattr() never finds a key on a dict,
        # so the HF_API_TOKEN branch was unreachable. Use key membership.
        if "HF_API_TOKEN" in self.env_vars:
            self.hf_model_config = _get_model_config_properties_from_hf(
                self.model, self.env_vars.get("HF_API_TOKEN")
            )
        else:
            self.hf_model_config = _get_model_config_properties_from_hf(
                self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
            )

        self.pysdk_model = self._create_hf_dlc_model()

        if self.mode == Mode.LOCAL_CONTAINER:
            self._prepare_for_mode()

        return self.pysdk_model

    def _set_instance(self, **kwargs):
        """Resolve the instance type from the detected notebook or settings.

        Returns:
            dict: the (possibly updated) kwargs, so the caller can merge the
            chosen ``instance_type`` into its own deploy arguments. The
            original mutated a by-value copy, losing the result.

        Raises:
            ValueError: in SageMaker Endpoint mode when no instance type can
                be determined.
        """
        if self.mode == Mode.SAGEMAKER_ENDPOINT:
            if "instance_type" not in kwargs:
                if self.nb_instance_type:
                    kwargs.update({"instance_type": self.nb_instance_type})
                elif self.instance_type:
                    kwargs.update({"instance_type": self.instance_type})
                else:
                    raise ValueError(
                        "Instance type must be provided when deploying to SageMaker Endpoint mode."
                    )
            # Log the instance type actually chosen (the original always
            # logged self.instance_type even when the notebook type won).
            logger.info("Setting instance type to %s", kwargs["instance_type"])
        return kwargs

    def _get_supported_version(self, hf_config, hugging_face_version, base_fw):
        """Uses the hugging face json config to pick supported versions.

        Args:
            hf_config (dict): ``image_uris`` config for the huggingface framework.
            hugging_face_version (str): transformers version to look under.
            base_fw (str): base framework prefix, e.g. ``"pytorch"``.

        Returns:
            str: lowest supported framework version for ``base_fw``.
        """
        version_config = hf_config.get("versions").get(hugging_face_version)
        versions_to_return = []
        for key in version_config:
            if key.startswith(base_fw):
                base_fw_version = key[len(base_fw):]
                # A two-part HF version maps to a two-part framework version.
                if len(hugging_face_version.split(".")) == 2:
                    base_fw_version = ".".join(base_fw_version.split(".")[:-1])
                versions_to_return.append(base_fw_version)
        # Fix: compare numerically, not lexicographically ("2.0" < "10.1"),
        # consistent with the Version-keyed sort in _create_hf_dlc_model.
        return sorted(versions_to_return, key=Version)[0]

    def _build_for_hf_dlc(self):
        """Method that triggers model build.

        Returns:
            HuggingFaceModel: PySDK model ready for deployment.
        """
        self.secret_key = None
        self.model_server = ModelServer.HuggingFaceDLC
        self.pysdk_model = self._build_for_hugging_face_dlc()
        return self.pysdk_model

src/sagemaker/serve/builder/model_builder.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from sagemaker.serve.builder.djl_builder import DJL
3535
from sagemaker.serve.builder.tgi_builder import TGI
3636
from sagemaker.serve.builder.jumpstart_builder import JumpStart
37+
from sagemaker.serve.builder.hf_dlc_builder import HuggingFaceDLC
3738
from sagemaker.predictor import Predictor
3839
from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import Metadata
3940
from sagemaker.serve.spec.inference_spec import InferenceSpec
@@ -53,19 +54,21 @@
5354
from sagemaker.serve.validations.check_image_and_hardware_type import (
5455
validate_image_uri_and_hardware,
5556
)
57+
from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata
5658

5759
logger = logging.getLogger(__name__)
5860

5961
supported_model_server = {
6062
ModelServer.TORCHSERVE,
6163
ModelServer.TRITON,
6264
ModelServer.DJL_SERVING,
65+
ModelServer.HuggingFaceDLC,
6366
}
6467

6568

6669
# pylint: disable=attribute-defined-outside-init
6770
@dataclass
68-
class ModelBuilder(Triton, DJL, JumpStart, TGI):
71+
class ModelBuilder(Triton, DJL, JumpStart, TGI, HuggingFaceDLC):
6972
"""Class that builds a deployable model.
7073
7174
Args:
@@ -125,7 +128,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI):
125128
in order for model builder to build the artifacts correctly (according
126129
to the model server). Possible values for this argument are
127130
``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``,
128-
``TRITON``, and ``TGI``.
131+
``TRITON``, ``TGI``, and ``HuggingFaceDLC``.
129132
130133
"""
131134

@@ -535,7 +538,7 @@ def wrapper(*args, **kwargs):
535538
return wrapper
536539

537540
# Model Builder is a class to build the model for deployment.
538-
# It supports three modes of deployment
541+
# It supports two modes of deployment
539542
# 1/ SageMaker Endpoint
540543
# 2/ Local launch with container
541544
def build(
@@ -577,12 +580,19 @@ def build(
577580
)
578581

579582
self.serve_settings = self._get_serve_setting()
583+
584+
hf_model_md = get_huggingface_model_metadata(self.model,
585+
self.env_vars.get("HUGGING_FACE_HUB_TOKEN"))
586+
580587
if isinstance(self.model, str):
581588
if self._is_jumpstart_model_id():
582589
return self._build_for_jumpstart()
583590
if self._is_djl():
584591
return self._build_for_djl()
585-
return self._build_for_tgi()
592+
if hf_model_md.get("pipeline_tag") == "text-generation":
593+
return self._build_for_tgi()
594+
else:
595+
return self._build_for_hf_dlc()
586596

587597
self._build_validations()
588598

src/sagemaker/serve/mode/local_container_mode.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from sagemaker.serve.model_server.djl_serving.server import LocalDJLServing
2020
from sagemaker.serve.model_server.triton.server import LocalTritonServer
2121
from sagemaker.serve.model_server.tgi.server import LocalTgiServing
22+
from sagemaker.serve.model_server.hf_dlc.server import LocalHFDLCServing
2223
from sagemaker.session import Session
2324

2425
logger = logging.getLogger(__name__)
@@ -31,7 +32,7 @@
3132
)
3233

3334

34-
class LocalContainerMode(LocalTorchServe, LocalDJLServing, LocalTritonServer, LocalTgiServing):
35+
class LocalContainerMode(LocalTorchServe, LocalDJLServing, LocalTritonServer, LocalTgiServing, LocalHFDLCServing):
3536
"""A class that holds methods to deploy model to a container in local environment"""
3637

3738
def __init__(
@@ -128,6 +129,15 @@ def create_server(
128129
jumpstart=jumpstart,
129130
)
130131
self._ping_container = self._tgi_deep_ping
132+
elif self.model_server == ModelServer.HuggingFaceDLC:
133+
self._start_hf_dlc_serving(
134+
client=self.client,
135+
image=image,
136+
model_path=model_path if model_path else self.model_path,
137+
secret_key=secret_key,
138+
env_vars=env_vars if env_vars else self.env_vars,
139+
)
140+
self._ping_container = self._hf_dlc_deep_ping
131141

132142
# allow some time for container to be ready
133143
time.sleep(10)

0 commit comments

Comments
 (0)