Skip to content

Commit ab48e1c

Browse files
committed
Introduce HF DLC to Model Builder
1 parent f2b47ab commit ab48e1c

File tree

12 files changed

+527
-5
lines changed

12 files changed

+527
-5
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Holds mixin logic to support deployment of Model ID"""
14+
from __future__ import absolute_import
15+
import logging
16+
from typing import Type
17+
from abc import ABC, abstractmethod
18+
19+
from sagemaker.model import Model
20+
from sagemaker.serve.utils.local_hardware import (
21+
_get_nb_instance,
22+
_get_ram_usage_mb,
23+
_get_gpu_info,
24+
_get_gpu_info_fallback,
25+
)
26+
27+
from sagemaker.djl_inference.model import _get_model_config_properties_from_hf
28+
from sagemaker.serve.model_server.djl_serving.utils import (
29+
_get_admissible_tensor_parallel_degrees,
30+
_get_default_tensor_parallel_degree,
31+
)
32+
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
33+
from sagemaker.serve.model_server.hf_dlc.prepare import _create_dir_structure
34+
from sagemaker.serve.utils.predictors import HfDLCLocalModePredictor
35+
from sagemaker.serve.utils.types import ModelServer
36+
from sagemaker.serve.mode.function_pointers import Mode
37+
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
38+
from sagemaker.base_predictor import PredictorBase
39+
from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata
40+
41+
logger = logging.getLogger(__name__)
42+
43+
44+
class HuggingFaceDLC(ABC):
    """Hugging Face DLC build logic for ModelBuilder().

    Mixin that resolves a Hugging Face Hub model ID into a deployable
    ``HuggingFaceModel`` backed by a framework Deep Learning Container image,
    and wraps its ``deploy()`` so the model can be served either in a local
    container or on a SageMaker endpoint.
    """

    def __init__(self):
        self.model = None
        self.serve_settings = None
        self.sagemaker_session = None
        self.model_path = None
        self.dependencies = None
        self.modes = None
        self.mode = None
        self.model_server = None
        self.image_uri = None
        # Original (unpatched) HuggingFaceModel.deploy, captured so the
        # wrapper can delegate to it for SageMaker endpoint deployments.
        self._original_deploy = None
        self.hf_model_config = None
        self._default_data_type = None
        self._default_max_tokens = None
        self.pysdk_model = None
        self.env_vars = None
        self.nb_instance_type = None
        self.ram_usage_model_load = None
        self.secret_key = None
        self.role_arn = None
        # Framework/runtime versions used to pick the DLC image; resolved
        # from Hub model metadata in _create_hf_dlc_model().
        self.transformers_version = None
        self.py_version = None
        self.pytorch_version = None
        self.tensorflow_version = None

    @abstractmethod
    def _prepare_for_mode(self):
        """Prepare artifacts for the active deployment mode (implemented by ModelBuilder)."""

    def _create_hf_dlc_model(self) -> Type[Model]:
        """Build a ``HuggingFaceModel`` for ``self.model`` and patch its ``deploy()``.

        Returns:
            HuggingFaceModel: the PySDK model with ``deploy`` replaced by
            ``_hf_dlc_model_builder_deploy_wrapper``.
        """
        hf_model_md = get_huggingface_model_metadata(
            self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
        )
        # "tags" can be absent from Hub metadata; guard against None so the
        # membership tests below cannot raise TypeError.
        tags = hf_model_md.get("tags") or []
        if "pytorch" in tags:
            self.pytorch_version = "1.8.1"
            self.py_version = "py36"
        elif "keras" in tags or "tensorflow" in tags:
            self.py_version = "py37"
            self.tensorflow_version = "2.4.1"

        self.transformers_version = "4.6.1"

        pysdk_model = HuggingFaceModel(
            env=self.env_vars,
            role=self.role_arn,
            sagemaker_session=self.sagemaker_session,
            py_version=self.py_version,
            transformers_version=self.transformers_version,
            pytorch_version=self.pytorch_version,
            # Fix: tensorflow_version was resolved above but never forwarded,
            # so keras/tensorflow models could not resolve a TF DLC image.
            tensorflow_version=self.tensorflow_version,
        )
        self.image_uri = pysdk_model.serving_image_uri(
            self.sagemaker_session.boto_region_name, "local"
        )

        self._original_deploy = pysdk_model.deploy
        pysdk_model.deploy = self._hf_dlc_model_builder_deploy_wrapper
        return pysdk_model

    @_capture_telemetry("hf_dlc.deploy")
    def _hf_dlc_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
        """Route ``deploy()`` between LOCAL_CONTAINER and SAGEMAKER_ENDPOINT.

        Honors a ``mode`` override passed at deploy time, wires the schema
        builder's (de)serializers onto the returned predictor, and raises
        ``ValueError`` for unsupported modes or a missing instance type.
        """
        timeout = kwargs.get("model_data_download_timeout")
        if timeout:
            self.pysdk_model.env.update({"MODEL_LOADING_TIMEOUT": str(timeout)})

        if "mode" in kwargs and kwargs.get("mode") != self.mode:
            overwrite_mode = kwargs.get("mode")
            # mode overwritten by customer during model.deploy()
            logger.warning(
                "Deploying in %s Mode, overriding existing configurations set for %s mode",
                overwrite_mode,
                self.mode,
            )

            if overwrite_mode == Mode.SAGEMAKER_ENDPOINT:
                self.mode = self.pysdk_model.mode = Mode.SAGEMAKER_ENDPOINT
            elif overwrite_mode == Mode.LOCAL_CONTAINER:
                self._prepare_for_mode()
                self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
            else:
                raise ValueError("Mode %s is not supported!" % overwrite_mode)

        serializer = self.schema_builder.input_serializer
        deserializer = self.schema_builder._output_deserializer
        if self.mode == Mode.LOCAL_CONTAINER:
            predictor = HfDLCLocalModePredictor(
                self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
            )

            # Measure RAM delta across server start as a proxy for the
            # model's load footprint (floored at 0 for noisy readings).
            ram_usage_before = _get_ram_usage_mb()
            self.modes[str(Mode.LOCAL_CONTAINER)].create_server(
                self.image_uri,
                timeout if timeout else 1800,
                None,
                predictor,
                self.pysdk_model.env,
                jumpstart=False,
            )
            ram_usage_after = _get_ram_usage_mb()
            self.ram_usage_model_load = max(ram_usage_after - ram_usage_before, 0)

            return predictor

        # SAGEMAKER_ENDPOINT path: strip ModelBuilder-only kwargs before
        # delegating to the original HuggingFaceModel.deploy().
        if "mode" in kwargs:
            del kwargs["mode"]
        if "role" in kwargs:
            self.pysdk_model.role = kwargs.get("role")
            del kwargs["role"]

        # set model_data to uncompressed s3 dict
        self.pysdk_model.model_data, env_vars = self._prepare_for_mode()
        self.env_vars.update(env_vars)
        self.pysdk_model.env.update(self.env_vars)

        if "endpoint_logging" not in kwargs:
            kwargs["endpoint_logging"] = True

        if self.nb_instance_type and "instance_type" not in kwargs:
            # Default to the notebook's own instance type when available.
            kwargs.update({"instance_type": self.nb_instance_type})
        elif not self.nb_instance_type and "instance_type" not in kwargs:
            raise ValueError(
                "Instance type must be provided when deploying " "to SageMaker Endpoint mode."
            )
        else:
            # Customer supplied an instance type: derive sharding from its
            # GPU count, falling back to a static lookup if the API call fails.
            try:
                tot_gpus = _get_gpu_info(kwargs.get("instance_type"), self.sagemaker_session)
            except Exception:  # pylint: disable=W0703
                tot_gpus = _get_gpu_info_fallback(kwargs.get("instance_type"))
            default_num_shard = _get_default_tensor_parallel_degree(self.hf_model_config, tot_gpus)
            self.pysdk_model.env.update(
                {
                    "NUM_SHARD": str(default_num_shard),
                    "SHARDED": "true" if default_num_shard > 1 else "false",
                }
            )

        if "initial_instance_count" not in kwargs:
            kwargs.update({"initial_instance_count": 1})

        predictor = self._original_deploy(*args, **kwargs)

        self.pysdk_model.env.update({"TRANSFORMERS_OFFLINE": "0"})

        predictor.serializer = serializer
        predictor.deserializer = deserializer
        return predictor

    def _build_for_hugging_face_dlc(self):
        """Prepare local artifacts and create the PySDK model for the HF DLC path."""
        self.nb_instance_type = _get_nb_instance()

        _create_dir_structure(self.model_path)
        if not hasattr(self, "pysdk_model"):
            self.env_vars.update({"HF_MODEL_ID": self.model})
            self.hf_model_config = _get_model_config_properties_from_hf(
                self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
            )

        self.pysdk_model = self._create_hf_dlc_model()

        if self.mode == Mode.LOCAL_CONTAINER:
            self._prepare_for_mode()

        return self.pysdk_model

    def _build_for_hf_dlc(self):
        """Entry point used by ModelBuilder.build() for the HF DLC model server."""
        self.secret_key = None

        self.model_server = ModelServer.HuggingFaceDLC

        self.pysdk_model = self._build_for_hugging_face_dlc()
        return self.pysdk_model

src/sagemaker/serve/builder/model_builder.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from sagemaker.serve.builder.djl_builder import DJL
3535
from sagemaker.serve.builder.tgi_builder import TGI
3636
from sagemaker.serve.builder.jumpstart_builder import JumpStart
37+
from sagemaker.serve.builder.hf_dlc_builder import HuggingFaceDLC
3738
from sagemaker.predictor import Predictor
3839
from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import Metadata
3940
from sagemaker.serve.spec.inference_spec import InferenceSpec
@@ -53,19 +54,21 @@
5354
from sagemaker.serve.validations.check_image_and_hardware_type import (
5455
validate_image_uri_and_hardware,
5556
)
57+
from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata
5658

5759
logger = logging.getLogger(__name__)
5860

5961
supported_model_server = {
6062
ModelServer.TORCHSERVE,
6163
ModelServer.TRITON,
6264
ModelServer.DJL_SERVING,
65+
ModelServer.HuggingFaceDLC,
6366
}
6467

6568

6669
# pylint: disable=attribute-defined-outside-init
6770
@dataclass
68-
class ModelBuilder(Triton, DJL, JumpStart, TGI):
71+
class ModelBuilder(Triton, DJL, JumpStart, TGI, HuggingFaceDLC):
6972
"""Class that builds a deployable model.
7073
7174
Args:
@@ -125,7 +128,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI):
125128
in order for model builder to build the artifacts correctly (according
126129
to the model server). Possible values for this argument are
127130
``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``,
128-
``TRITON``, and ``TGI``.
131+
``TRITON``, ``TGI``, and ``HuggingFaceDLC``.
129132
130133
"""
131134

@@ -577,12 +580,19 @@ def build(
577580
)
578581

579582
self.serve_settings = self._get_serve_setting()
583+
584+
hf_model_md = get_huggingface_model_metadata(self.model,
585+
self.env_vars.get("HUGGING_FACE_HUB_TOKEN"))
586+
580587
if isinstance(self.model, str):
581588
if self._is_jumpstart_model_id():
582589
return self._build_for_jumpstart()
583590
if self._is_djl():
584591
return self._build_for_djl()
585-
return self._build_for_tgi()
592+
if hf_model_md.get("pipeline_tag") == "text-generation":
593+
return self._build_for_tgi()
594+
else:
595+
return self._build_for_hf_dlc()
586596

587597
self._build_validations()
588598

src/sagemaker/serve/mode/local_container_mode.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from sagemaker.serve.model_server.djl_serving.server import LocalDJLServing
2020
from sagemaker.serve.model_server.triton.server import LocalTritonServer
2121
from sagemaker.serve.model_server.tgi.server import LocalTgiServing
22+
from sagemaker.serve.model_server.hf_dlc.server import LocalHFDLCServing
2223
from sagemaker.session import Session
2324

2425
logger = logging.getLogger(__name__)
@@ -31,7 +32,7 @@
3132
)
3233

3334

34-
class LocalContainerMode(LocalTorchServe, LocalDJLServing, LocalTritonServer, LocalTgiServing):
35+
class LocalContainerMode(LocalTorchServe, LocalDJLServing, LocalTritonServer, LocalTgiServing, LocalHFDLCServing):
3536
"""A class that holds methods to deploy model to a container in local environment"""
3637

3738
def __init__(
@@ -128,6 +129,15 @@ def create_server(
128129
jumpstart=jumpstart,
129130
)
130131
self._ping_container = self._tgi_deep_ping
132+
elif self.model_server == ModelServer.HuggingFaceDLC:
133+
self._start_hf_dlc_serving(
134+
client=self.client,
135+
image=image,
136+
model_path=model_path if model_path else self.model_path,
137+
secret_key=secret_key,
138+
env_vars=env_vars if env_vars else self.env_vars,
139+
)
140+
self._ping_container = self._hf_dlc_deep_ping
131141

132142
# allow some time for container to be ready
133143
time.sleep(10)

src/sagemaker/serve/mode/sagemaker_endpoint_mode.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
from sagemaker.serve.model_server.torchserve.server import SageMakerTorchServe
1313
from sagemaker.serve.model_server.djl_serving.server import SageMakerDjlServing
1414
from sagemaker.serve.model_server.tgi.server import SageMakerTgiServing
15+
from sagemaker.serve.model_server.hf_dlc.server import SageMakerHFDLCServing
1516

1617
logger = logging.getLogger(__name__)
1718

1819

1920
class SageMakerEndpointMode(
20-
SageMakerTorchServe, SageMakerTritonServer, SageMakerDjlServing, SageMakerTgiServing
21+
SageMakerTorchServe, SageMakerTritonServer, SageMakerDjlServing, SageMakerTgiServing, SageMakerHFDLCServing
2122
):
2223
"""Holds the required method to deploy a model to a SageMaker Endpoint"""
2324

@@ -92,5 +93,12 @@ def prepare(
9293
image=image,
9394
jumpstart=jumpstart,
9495
)
96+
if self.model_server == ModelServer.HuggingFaceDLC:
97+
return self._upload_hf_dlc_artifacts(
98+
model_path=model_path,
99+
sagemaker_session=sagemaker_session,
100+
s3_model_data_url=s3_model_data_url,
101+
image=image,
102+
)
95103

96104
raise ValueError("%s model server is not supported" % self.model_server)

src/sagemaker/serve/model_server/hf_dlc/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Prepare HF DLC Model for Deployment"""
14+
15+
from __future__ import absolute_import
16+
import logging
17+
from pathlib import Path
18+
19+
from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage
20+
21+
logger = logging.getLogger(__name__)
22+
23+
24+
def _create_dir_structure(model_path: str) -> tuple:
    """Ensure the on-disk layout the HF DLC server expects.

    Creates ``model_path`` (and parents) if needed, guarantees a ``code``
    subdirectory exists, and runs host/docker disk-usage checks.

    Returns:
        tuple: ``(model_path, code_dir)`` as ``Path`` objects.

    Raises:
        ValueError: if ``model_path`` exists but is not a directory.
    """
    root = Path(model_path)
    if root.exists():
        if not root.is_dir():
            raise ValueError("model_dir is not a valid directory")
    else:
        root.mkdir(parents=True)

    code_dir = root / "code"
    code_dir.mkdir(exist_ok=True, parents=True)

    _check_disk_space(root)
    _check_docker_disk_usage()

    return root, code_dir

0 commit comments

Comments
 (0)