Skip to content

Commit 7ee395a

Browse files
committed
feat: Introduce HF DLC to ModelBuilder
1 parent 086c946 commit 7ee395a

18 files changed

+595
-43
lines changed

src/sagemaker/serve/builder/djl_builder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@
4646
)
4747
from sagemaker.serve.model_server.djl_serving.prepare import (
4848
prepare_for_djl_serving,
49-
_create_dir_structure,
5049
)
50+
from sagemaker.serve.utils.global_prepare import _create_dir_structure
5151
from sagemaker.serve.utils.predictors import DjlLocalModePredictor
5252
from sagemaker.serve.utils.types import ModelServer, _DjlEngine
5353
from sagemaker.serve.mode.function_pointers import Mode
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Holds mixin logic to support deployment of Model ID"""
14+
from __future__ import absolute_import
15+
import logging
16+
from typing import Type
17+
from abc import ABC, abstractmethod
18+
19+
from sagemaker.model import Model
20+
from sagemaker.serve.utils.local_hardware import (
21+
_get_nb_instance,
22+
_get_ram_usage_mb,
23+
_get_gpu_info,
24+
_get_gpu_info_fallback,
25+
)
26+
27+
from sagemaker.djl_inference.model import _get_model_config_properties_from_hf
28+
from sagemaker.serve.model_server.djl_serving.utils import (
29+
_get_admissible_tensor_parallel_degrees,
30+
_get_default_tensor_parallel_degree,
31+
)
32+
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
33+
from sagemaker.serve.utils.global_prepare import _create_dir_structure
34+
from sagemaker.serve.utils.predictors import HfDLCLocalModePredictor
35+
from sagemaker.serve.utils.types import ModelServer
36+
from sagemaker.serve.mode.function_pointers import Mode
37+
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
38+
from sagemaker.base_predictor import PredictorBase
39+
from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata
40+
41+
logger = logging.getLogger(__name__)
42+
43+
44+
class HuggingFaceDLC(ABC):
    """HuggingFace DLC build logic for ModelBuilder().

    Mixin that teaches ModelBuilder how to build and deploy a Hugging Face
    Hub model through the SageMaker HuggingFace Deep Learning Containers,
    in either LOCAL_CONTAINER or SAGEMAKER_ENDPOINT mode.
    """

    def __init__(self):
        # Shared ModelBuilder state; populated by the builder before
        # _build_for_hf_dlc() is invoked.
        self.model = None
        self.serve_settings = None
        self.sagemaker_session = None
        self.model_path = None
        self.dependencies = None
        self.modes = None
        self.mode = None
        self.model_server = None
        self.image_uri = None
        # Original pysdk_model.deploy, saved so the wrapper can delegate.
        self._original_deploy = None
        self.hf_model_config = None
        self._default_data_type = None
        self._default_max_tokens = None
        self.pysdk_model = None
        self.env_vars = None
        self.nb_instance_type = None
        # RAM delta (MB) observed while loading the model in local mode.
        self.ram_usage_model_load = None
        self.secret_key = None
        self.role_arn = None
        self.transformers_version = None
        self.py_version = None
        self.pytorch_version = None
        self.tensorflow_version = None

    @abstractmethod
    def _prepare_for_mode(self):
        """Prepare artifacts/env vars for the currently selected deploy mode."""

    def _create_hf_dlc_model(self, **kwargs) -> Type[Model]:
        """Create a HuggingFaceModel and resolve its serving image URI.

        Returns:
            The HuggingFaceModel with ``deploy`` rerouted through
            :meth:`_hf_dlc_model_builder_deploy_wrapper`.

        Raises:
            ValueError: If no instance type can be determined for SageMaker
                Endpoint mode, or HF Hub metadata cannot be fetched.
        """
        # Resolve the instance type only when the caller did not supply one.
        # (Fix: previously a caller-supplied instance_type was overwritten
        # with "local" in LOCAL_CONTAINER mode and rejected with a ValueError
        # in endpoint mode.)
        if "instance_type" not in kwargs:
            if self.nb_instance_type:
                kwargs.update({"instance_type": self.nb_instance_type})
            elif self.mode == Mode.LOCAL_CONTAINER:
                kwargs.update({"instance_type": "local"})
            else:
                raise ValueError(
                    "Instance type must be provided when deploying " "to SageMaker Endpoint mode."
                )

        hf_model_md = get_huggingface_model_metadata(
            self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
        )
        if hf_model_md is None:
            raise ValueError("Could not fetch HF metadata")

        # Pick framework versions from the Hub model tags. "tags" may be
        # missing from the metadata, so default to an empty list instead of
        # crashing on `in None`.
        tags = hf_model_md.get("tags") or []
        if "pytorch" in tags:
            # NOTE(review): pytorch 1.8.1 DLCs predate py310 images — confirm
            # that ("1.8.1", "py310") resolves to a real container image.
            self.pytorch_version = "1.8.1"
            self.py_version = "py310"
        elif "keras" in tags or "tensorflow" in tags:
            self.py_version = "py37"
            self.tensorflow_version = "2.4.1"

        self.transformers_version = "4.6.1"

        # NOTE(review): tensorflow_version is recorded above but never passed
        # to HuggingFaceModel — verify TF models resolve the intended image.
        pysdk_model = HuggingFaceModel(
            env=self.env_vars,
            role=self.role_arn,
            sagemaker_session=self.sagemaker_session,
            py_version=self.py_version,
            transformers_version=self.transformers_version,
            pytorch_version=self.pytorch_version,
        )
        self.image_uri = pysdk_model.serving_image_uri(
            self.sagemaker_session.boto_region_name, kwargs.get("instance_type")
        )
        logger.info("Detected %s. Proceeding with the deployment.", self.image_uri)

        # Route deploy() through our wrapper so mode handling happens first.
        self._original_deploy = pysdk_model.deploy
        pysdk_model.deploy = self._hf_dlc_model_builder_deploy_wrapper
        return pysdk_model

    @_capture_telemetry("hf_dlc.deploy")
    def _hf_dlc_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
        """Deploy wrapper honoring mode overrides and local-container serving.

        Returns:
            A local-mode predictor when in LOCAL_CONTAINER mode, otherwise
            the predictor returned by the original ``deploy``.

        Raises:
            ValueError: If an unsupported mode override is requested.
        """
        # A "mode" kwarg at deploy() time overrides the mode chosen at build.
        if "mode" in kwargs and kwargs.get("mode") != self.mode:
            overwrite_mode = kwargs.get("mode")
            # mode overwritten by customer during model.deploy()
            logger.warning(
                "Deploying in %s Mode, overriding existing configurations set for %s mode",
                overwrite_mode,
                self.mode,
            )

            if overwrite_mode == Mode.SAGEMAKER_ENDPOINT:
                self.mode = self.pysdk_model.mode = Mode.SAGEMAKER_ENDPOINT
            elif overwrite_mode == Mode.LOCAL_CONTAINER:
                self._prepare_for_mode()
                self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
            else:
                raise ValueError("Mode %s is not supported!" % overwrite_mode)

        serializer = self.schema_builder.input_serializer
        deserializer = self.schema_builder._output_deserializer
        if self.mode == Mode.LOCAL_CONTAINER:
            timeout = kwargs.get("model_data_download_timeout")

            predictor = HfDLCLocalModePredictor(
                self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
            )

            # Record the RAM delta caused by loading the model locally.
            ram_usage_before = _get_ram_usage_mb()
            self.modes[str(Mode.LOCAL_CONTAINER)].create_server(
                self.image_uri,
                timeout if timeout else 1800,
                None,
                predictor,
                self.pysdk_model.env,
                jumpstart=False,
            )
            ram_usage_after = _get_ram_usage_mb()
            self.ram_usage_model_load = max(ram_usage_after - ram_usage_before, 0)

            return predictor

        # SAGEMAKER_ENDPOINT path: strip builder-only kwargs before delegating.
        if "mode" in kwargs:
            del kwargs["mode"]
        if "role" in kwargs:
            self.pysdk_model.role = kwargs.get("role")
            del kwargs["role"]

        # set model_data to uncompressed s3 dict
        self.pysdk_model.model_data, env_vars = self._prepare_for_mode()
        self.env_vars.update(env_vars)
        self.pysdk_model.env.update(self.env_vars)

        if "endpoint_logging" not in kwargs:
            kwargs["endpoint_logging"] = True

        if "initial_instance_count" not in kwargs:
            kwargs.update({"initial_instance_count": 1})

        predictor = self._original_deploy(*args, **kwargs)

        predictor.serializer = serializer
        predictor.deserializer = deserializer
        return predictor

    def _build_for_hugging_face_dlc(self):
        """Build the HF DLC pysdk model, creating it once and reusing it."""
        self.nb_instance_type = _get_nb_instance()

        _create_dir_structure(self.model_path)
        if not hasattr(self, "pysdk_model"):
            self.env_vars.update({"HF_MODEL_ID": self.model})
            self.hf_model_config = _get_model_config_properties_from_hf(
                self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
            )

        self.pysdk_model = self._create_hf_dlc_model()

        if self.mode == Mode.LOCAL_CONTAINER:
            self._prepare_for_mode()

        return self.pysdk_model

    def _build_for_hf_dlc(self):
        """Entry point: select the HF DLC model server and build the model."""
        self.secret_key = None

        self.model_server = ModelServer.HuggingFaceDLC

        self.pysdk_model = self._build_for_hugging_face_dlc()
        return self.pysdk_model

src/sagemaker/serve/builder/jumpstart_builder.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
from sagemaker.model import Model
2121
from sagemaker import model_uris
2222
from sagemaker.serve.model_server.djl_serving.prepare import prepare_djl_js_resources
23-
from sagemaker.serve.model_server.tgi.prepare import prepare_tgi_js_resources, _create_dir_structure
23+
from sagemaker.serve.model_server.tgi.prepare import prepare_tgi_js_resources
24+
from sagemaker.serve.utils.global_prepare import _create_dir_structure
2425
from sagemaker.serve.mode.function_pointers import Mode
2526
from sagemaker.serve.utils.predictors import (
2627
DjlLocalModePredictor,

src/sagemaker/serve/builder/model_builder.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from sagemaker.serve.builder.djl_builder import DJL
3535
from sagemaker.serve.builder.tgi_builder import TGI
3636
from sagemaker.serve.builder.jumpstart_builder import JumpStart
37+
from sagemaker.serve.builder.hf_dlc_builder import HuggingFaceDLC
3738
from sagemaker.predictor import Predictor
3839
from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import Metadata
3940
from sagemaker.serve.spec.inference_spec import InferenceSpec
@@ -53,19 +54,21 @@
5354
from sagemaker.serve.validations.check_image_and_hardware_type import (
5455
validate_image_uri_and_hardware,
5556
)
57+
from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata
5658

5759
logger = logging.getLogger(__name__)
5860

5961
supported_model_server = {
6062
ModelServer.TORCHSERVE,
6163
ModelServer.TRITON,
6264
ModelServer.DJL_SERVING,
65+
ModelServer.HuggingFaceDLC,
6366
}
6467

6568

6669
# pylint: disable=attribute-defined-outside-init
6770
@dataclass
68-
class ModelBuilder(Triton, DJL, JumpStart, TGI):
71+
class ModelBuilder(Triton, DJL, JumpStart, TGI, HuggingFaceDLC):
6972
"""Class that builds a deployable model.
7073
7174
Args:
@@ -125,7 +128,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI):
125128
in order for model builder to build the artifacts correctly (according
126129
to the model server). Possible values for this argument are
127130
``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``,
128-
``TRITON``, and ``TGI``.
131+
``TRITON``, ``TGI``, and ``HuggingFaceDLC``.
129132
130133
"""
131134

@@ -535,7 +538,7 @@ def wrapper(*args, **kwargs):
535538
return wrapper
536539

537540
# Model Builder is a class to build the model for deployment.
538-
# It supports three modes of deployment
541+
# It supports two modes of deployment
539542
# 1/ SageMaker Endpoint
540543
# 2/ Local launch with container
541544
def build(
@@ -577,12 +580,19 @@ def build(
577580
)
578581

579582
self.serve_settings = self._get_serve_setting()
583+
584+
hf_model_md = get_huggingface_model_metadata(self.model,
585+
self.env_vars.get("HUGGING_FACE_HUB_TOKEN"))
586+
580587
if isinstance(self.model, str):
581588
if self._is_jumpstart_model_id():
582589
return self._build_for_jumpstart()
583590
if self._is_djl():
584591
return self._build_for_djl()
585-
return self._build_for_tgi()
592+
if hf_model_md.get("pipeline_tag") == "text-generation":
593+
return self._build_for_tgi()
594+
else:
595+
return self._build_for_hf_dlc()
586596

587597
self._build_validations()
588598

src/sagemaker/serve/builder/tgi_builder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
_get_gpu_info,
4848
_get_gpu_info_fallback,
4949
)
50-
from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure
50+
from sagemaker.serve.utils.global_prepare import _create_dir_structure
5151
from sagemaker.serve.utils.predictors import TgiLocalModePredictor
5252
from sagemaker.serve.utils.types import ModelServer
5353
from sagemaker.serve.mode.function_pointers import Mode

src/sagemaker/serve/mode/local_container_mode.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from sagemaker.serve.model_server.djl_serving.server import LocalDJLServing
2020
from sagemaker.serve.model_server.triton.server import LocalTritonServer
2121
from sagemaker.serve.model_server.tgi.server import LocalTgiServing
22+
from sagemaker.serve.model_server.hf_dlc.server import LocalHFDLCServing
2223
from sagemaker.session import Session
2324

2425
logger = logging.getLogger(__name__)
@@ -31,7 +32,7 @@
3132
)
3233

3334

34-
class LocalContainerMode(LocalTorchServe, LocalDJLServing, LocalTritonServer, LocalTgiServing):
35+
class LocalContainerMode(LocalTorchServe, LocalDJLServing, LocalTritonServer, LocalTgiServing, LocalHFDLCServing):
3536
"""A class that holds methods to deploy model to a container in local environment"""
3637

3738
def __init__(
@@ -128,6 +129,15 @@ def create_server(
128129
jumpstart=jumpstart,
129130
)
130131
self._ping_container = self._tgi_deep_ping
132+
elif self.model_server == ModelServer.HuggingFaceDLC:
133+
self._start_hf_dlc_serving(
134+
client=self.client,
135+
image=image,
136+
model_path=model_path if model_path else self.model_path,
137+
secret_key=secret_key,
138+
env_vars=env_vars if env_vars else self.env_vars,
139+
)
140+
self._ping_container = self._hf_dlc_deep_ping
131141

132142
# allow some time for container to be ready
133143
time.sleep(10)

src/sagemaker/serve/mode/sagemaker_endpoint_mode.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
from sagemaker.serve.model_server.torchserve.server import SageMakerTorchServe
1313
from sagemaker.serve.model_server.djl_serving.server import SageMakerDjlServing
1414
from sagemaker.serve.model_server.tgi.server import SageMakerTgiServing
15+
from sagemaker.serve.model_server.hf_dlc.server import SageMakerHFDLCServing
1516

1617
logger = logging.getLogger(__name__)
1718

1819

1920
class SageMakerEndpointMode(
20-
SageMakerTorchServe, SageMakerTritonServer, SageMakerDjlServing, SageMakerTgiServing
21+
SageMakerTorchServe, SageMakerTritonServer, SageMakerDjlServing, SageMakerTgiServing, SageMakerHFDLCServing
2122
):
2223
"""Holds the required method to deploy a model to a SageMaker Endpoint"""
2324

@@ -92,5 +93,12 @@ def prepare(
9293
image=image,
9394
jumpstart=jumpstart,
9495
)
96+
if self.model_server == ModelServer.HuggingFaceDLC:
97+
return self._upload_hf_dlc_artifacts(
98+
model_path=model_path,
99+
sagemaker_session=sagemaker_session,
100+
s3_model_data_url=s3_model_data_url,
101+
image=image,
102+
)
95103

96104
raise ValueError("%s model server is not supported" % self.model_server)

src/sagemaker/serve/model_server/djl_serving/prepare.py

-17
Original file line numberDiff line numberDiff line change
@@ -157,23 +157,6 @@ def _copy_inference_script(code_dir):
157157
shutil.copy2(inference_file, code_dir)
158158

159159

160-
def _create_dir_structure(model_path: str) -> tuple:
161-
"""Placeholder Docstring"""
162-
model_path = Path(model_path)
163-
if not model_path.exists():
164-
model_path.mkdir(parents=True)
165-
elif not model_path.is_dir():
166-
raise ValueError("model_dir is not a valid directory")
167-
168-
code_dir = model_path.joinpath("code")
169-
code_dir.mkdir(exist_ok=True, parents=True)
170-
171-
_check_disk_space(model_path)
172-
_check_docker_disk_usage()
173-
174-
return (model_path, code_dir)
175-
176-
177160
def prepare_for_djl_serving(
178161
model_path: str,
179162
model: DJLModel,

src/sagemaker/serve/model_server/hf_dlc/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)