Skip to content

Commit 32917ac

Browse files
samrudsjiapinw
authored and committed
Feat: Pull latest tei container for sentence similarity models on HuggingFace hub (aws#4686)
* Update: Pull latest tei container for sentence similiarity models * Fix formatting * Address PR comments * Fix formatting * Fix check * Switch sentence similarity to be deployed on tgi * Fix formatting * Fix formatting * Fix formatting * Fix formatting * Introduce TEI builder with TGI server * Fix formmatting * Add integ test * Fix formatting * Add integ test * Add integ test * Add integ test * Add integ test * Add integ test * Fix formatting * Move to G5 for integ test * Fix formatting * Integ test updates * Integ test updates * Integ test updates * Fix formatting * Integ test updates * Move back to generate for ping * Integ test updates * Integ test updates
1 parent f234b5a commit 32917ac

File tree

6 files changed

+543
-5
lines changed

6 files changed

+543
-5
lines changed

src/sagemaker/serve/builder/model_builder.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from sagemaker.serve.detector.pickler import save_pkl, save_xgboost
3737
from sagemaker.serve.builder.serve_settings import _ServeSettings
3838
from sagemaker.serve.builder.djl_builder import DJL
39+
from sagemaker.serve.builder.tei_builder import TEI
3940
from sagemaker.serve.builder.tgi_builder import TGI
4041
from sagemaker.serve.builder.fastapi_builder import FastAPIServe
4142
from sagemaker.serve.builder.jumpstart_builder import JumpStart
@@ -97,9 +98,9 @@
9798
}
9899

99100

100-
# pylint: disable=attribute-defined-outside-init, disable=E1101, disable=R0901
101+
# pylint: disable=attribute-defined-outside-init, disable=E1101, disable=R0901, disable=R1705
101102
@dataclass
102-
class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing, FastAPIServe):
103+
class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing, TEI, FastAPIServe):
103104
"""Class that builds a deployable model.
104105
105106
Args:
@@ -755,7 +756,7 @@ def build( # pylint: disable=R0911
755756
model_task = self.model_metadata.get("HF_TASK")
756757
if self._is_jumpstart_model_id():
757758
return self._build_for_jumpstart()
758-
if self._is_djl(): # pylint: disable=R1705
759+
if self._is_djl():
759760
return self._build_for_djl()
760761
else:
761762
hf_model_md = get_huggingface_model_metadata(
@@ -766,8 +767,10 @@ def build( # pylint: disable=R0911
766767
model_task = hf_model_md.get("pipeline_tag")
767768
if self.schema_builder is None and model_task is not None:
768769
self._hf_schema_builder_init(model_task)
769-
if model_task == "text-generation": # pylint: disable=R1705
770+
if model_task == "text-generation":
770771
return self._build_for_tgi()
772+
if model_task == "sentence-similarity":
773+
return self._build_for_tei()
771774
elif self._can_fit_on_single_gpu():
772775
return self._build_for_transformers()
773776
elif (
+222
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Holds mixin logic to support deployment of Model ID"""
14+
from __future__ import absolute_import
15+
import logging
16+
from typing import Type
17+
from abc import ABC, abstractmethod
18+
19+
from sagemaker import image_uris
20+
from sagemaker.model import Model
21+
from sagemaker.djl_inference.model import _get_model_config_properties_from_hf
22+
23+
from sagemaker.huggingface import HuggingFaceModel
24+
from sagemaker.serve.utils.local_hardware import (
25+
_get_nb_instance,
26+
)
27+
from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure
28+
from sagemaker.serve.utils.predictors import TgiLocalModePredictor
29+
from sagemaker.serve.utils.types import ModelServer
30+
from sagemaker.serve.mode.function_pointers import Mode
31+
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
32+
from sagemaker.base_predictor import PredictorBase
33+
34+
logger = logging.getLogger(__name__)
35+
36+
_CODE_FOLDER = "code"
37+
38+
39+
class TEI(ABC):
    """TEI build logic for ModelBuilder().

    Mixin that prepares a HuggingFace sentence-similarity model for deployment.
    Despite the name, the model is served through the TGI model server
    (``_set_to_tgi`` forces ``ModelServer.TGI``); the image pulled is the
    ``huggingface-tei`` inference container.
    """

    def __init__(self):
        # Attributes shared with the other ModelBuilder mixins; they are
        # populated by ModelBuilder itself before the _build_for_tei() path runs.
        self.model = None
        self.serve_settings = None
        self.sagemaker_session = None
        self.model_path = None
        self.dependencies = None
        self.modes = None
        self.mode = None
        self.model_server = None
        self.image_uri = None
        self._is_custom_image_uri = False
        self.image_config = None
        self.vpc_config = None
        self._original_deploy = None
        self.hf_model_config = None
        self._default_tensor_parallel_degree = None
        self._default_data_type = None
        self._default_max_tokens = None
        self.pysdk_model = None
        self.schema_builder = None
        self.env_vars = None
        self.nb_instance_type = None
        self.ram_usage_model_load = None
        self.secret_key = None
        self.jumpstart = None
        self.role_arn = None

    @abstractmethod
    def _prepare_for_mode(self):
        """Prepare artifacts for the active deployment mode (implemented by ModelBuilder)."""

    @abstractmethod
    def _get_client_translators(self):
        """Return the serializer/deserializer pair for the client (implemented by ModelBuilder)."""

    def _set_to_tgi(self):
        """Force the model server to TGI, warning if another server was configured."""
        if self.model_server != ModelServer.TGI:
            messaging = (
                "HuggingFace Model ID support on model server: "
                f"{self.model_server} is not currently supported. "
                f"Defaulting to {ModelServer.TGI}"
            )
            logger.warning(messaging)
            self.model_server = ModelServer.TGI

    def _create_tei_model(self, **kwargs) -> Type[Model]:
        """Build a HuggingFaceModel backed by the TEI inference image.

        Resolves the ``huggingface-tei`` image URI when none was supplied and
        wraps ``deploy()`` so that mode-specific handling runs first.

        Returns:
            The configured ``HuggingFaceModel`` with a patched ``deploy``.
        """
        # Prefer the notebook instance type detected at build time unless the
        # caller explicitly passed one.
        if self.nb_instance_type and "instance_type" not in kwargs:
            kwargs.update({"instance_type": self.nb_instance_type})

        if not self.image_uri:
            self.image_uri = image_uris.retrieve(
                "huggingface-tei",
                image_scope="inference",
                instance_type=kwargs.get("instance_type"),
                region=self.sagemaker_session.boto_region_name,
            )

        pysdk_model = HuggingFaceModel(
            image_uri=self.image_uri,
            image_config=self.image_config,
            vpc_config=self.vpc_config,
            env=self.env_vars,
            role=self.role_arn,
            sagemaker_session=self.sagemaker_session,
        )

        # Fixed duplicated word ("the the") in the original log message.
        logger.info("Detected %s. Proceeding with the deployment.", self.image_uri)

        # Patch deploy() so our wrapper can intercept mode overrides and
        # environment tweaks before delegating to the SDK's deploy.
        self._original_deploy = pysdk_model.deploy
        pysdk_model.deploy = self._tei_model_builder_deploy_wrapper
        return pysdk_model

    @_capture_telemetry("tei.deploy")
    def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
        """Deploy wrapper handling LOCAL_CONTAINER vs SAGEMAKER_ENDPOINT modes.

        Raises:
            ValueError: if an unsupported mode is requested, or if no instance
                type can be determined for a SageMaker endpoint deployment.

        Returns:
            A predictor wired with the schema builder's serializer/deserializer.
        """
        timeout = kwargs.get("model_data_download_timeout")
        if timeout:
            self.pysdk_model.env.update({"MODEL_LOADING_TIMEOUT": str(timeout)})

        if "mode" in kwargs and kwargs.get("mode") != self.mode:
            overwrite_mode = kwargs.get("mode")
            # mode overwritten by customer during model.deploy()
            logger.warning(
                "Deploying in %s Mode, overriding existing configurations set for %s mode",
                overwrite_mode,
                self.mode,
            )

            if overwrite_mode == Mode.SAGEMAKER_ENDPOINT:
                self.mode = self.pysdk_model.mode = Mode.SAGEMAKER_ENDPOINT
            elif overwrite_mode == Mode.LOCAL_CONTAINER:
                self._prepare_for_mode()
                self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
            else:
                raise ValueError("Mode %s is not supported!" % overwrite_mode)

        serializer = self.schema_builder.input_serializer
        deserializer = self.schema_builder._output_deserializer
        if self.mode == Mode.LOCAL_CONTAINER:
            timeout = kwargs.get("model_data_download_timeout")

            predictor = TgiLocalModePredictor(
                self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
            )

            # Default the container start timeout to 30 minutes when unset.
            self.modes[str(Mode.LOCAL_CONTAINER)].create_server(
                self.image_uri,
                timeout if timeout else 1800,
                None,
                predictor,
                self.pysdk_model.env,
                jumpstart=False,
            )

            return predictor

        # SAGEMAKER_ENDPOINT path: strip wrapper-only kwargs before delegating.
        if "mode" in kwargs:
            del kwargs["mode"]
        if "role" in kwargs:
            self.pysdk_model.role = kwargs.get("role")
            del kwargs["role"]

        # set model_data to uncompressed s3 dict
        self.pysdk_model.model_data, env_vars = self._prepare_for_mode()
        self.env_vars.update(env_vars)
        self.pysdk_model.env.update(self.env_vars)

        # if the weights have been cached via local container mode -> set to offline
        if str(Mode.LOCAL_CONTAINER) in self.modes:
            self.pysdk_model.env.update({"TRANSFORMERS_OFFLINE": "1"})
        else:
            # if has not been built for local container we must use cache
            # that hosting has write access to.
            self.pysdk_model.env["TRANSFORMERS_CACHE"] = "/tmp"
            self.pysdk_model.env["HUGGINGFACE_HUB_CACHE"] = "/tmp"

        if "endpoint_logging" not in kwargs:
            kwargs["endpoint_logging"] = True

        if not self.nb_instance_type and "instance_type" not in kwargs:
            raise ValueError(
                "Instance type must be provided when deploying " "to SageMaker Endpoint mode."
            )

        if "initial_instance_count" not in kwargs:
            kwargs.update({"initial_instance_count": 1})

        predictor = self._original_deploy(*args, **kwargs)

        predictor.serializer = serializer
        predictor.deserializer = deserializer
        return predictor

    def _build_for_hf_tei(self):
        """Build the PySDK model for a HuggingFace model id served via TEI.

        Returns:
            The configured ``HuggingFaceModel``.
        """
        self.nb_instance_type = _get_nb_instance()

        _create_dir_structure(self.model_path)
        # NOTE(review): this __init__ sets self.pysdk_model = None, so
        # hasattr() is always True when __init__ ran and this branch would be
        # skipped; the guard is only effective when ModelBuilder's dataclass
        # construction bypasses this __init__ — confirm intent upstream.
        if not hasattr(self, "pysdk_model"):
            self.env_vars.update({"HF_MODEL_ID": self.model})
            self.hf_model_config = _get_model_config_properties_from_hf(
                self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
            )

        self.pysdk_model = self._create_tei_model()

        if self.mode == Mode.LOCAL_CONTAINER:
            self._prepare_for_mode()

        return self.pysdk_model

    def _build_for_tei(self):
        """Entry point from ModelBuilder.build() for sentence-similarity models.

        Returns:
            The built PySDK model ready for ``deploy()``.
        """
        self.secret_key = None

        self._set_to_tgi()

        self.pysdk_model = self._build_for_hf_tei()
        return self.pysdk_model
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
from sagemaker.serve.builder.schema_builder import SchemaBuilder
17+
from sagemaker.serve.builder.model_builder import ModelBuilder, Mode
18+
19+
from tests.integ.sagemaker.serve.constants import (
20+
HF_DIR,
21+
PYTHON_VERSION_IS_NOT_310,
22+
SERVE_SAGEMAKER_ENDPOINT_TIMEOUT,
23+
)
24+
25+
from tests.integ.timeout import timeout
26+
from tests.integ.utils import cleanup_model_resources
27+
import logging
28+
29+
logger = logging.getLogger(__name__)
30+
31+
# Fill-mask style request payload used to seed the SchemaBuilder for the
# TEI (sentence-similarity) integ test.
sample_input = {
    "inputs": "The man worked as a [MASK].",
}

# Representative fill-mask response shape (top-5 token candidates with
# scores) paired with sample_input so SchemaBuilder can derive the
# serializer/deserializer for the endpoint.
loaded_response = [
    {
        "score": 0.0974755585193634,
        "token": 10533,
        "token_str": "carpenter",
        "sequence": "the man worked as a carpenter.",
    },
    {
        "score": 0.052383411675691605,
        "token": 15610,
        "token_str": "waiter",
        "sequence": "the man worked as a waiter.",
    },
    {
        "score": 0.04962712526321411,
        "token": 13362,
        "token_str": "barber",
        "sequence": "the man worked as a barber.",
    },
    {
        "score": 0.0378861166536808,
        "token": 15893,
        "token_str": "mechanic",
        "sequence": "the man worked as a mechanic.",
    },
    {
        "score": 0.037680838257074356,
        "token": 18968,
        "token_str": "salesman",
        "sequence": "the man worked as a salesman.",
    },
]
67+
68+
69+
@pytest.fixture
def model_input():
    """Single fill-mask payload sent to the deployed endpoint."""
    payload = {"inputs": "The man worked as a [MASK]."}
    return payload
72+
73+
74+
@pytest.fixture
def model_builder_model_schema_builder():
    """ModelBuilder configured for the BAAI/bge-m3 sentence-similarity model."""
    schema = SchemaBuilder(sample_input, loaded_response)
    metadata = {
        "HF_TASK": "sentence-similarity",
    }
    return ModelBuilder(
        model_path=HF_DIR,
        model="BAAI/bge-m3",
        schema_builder=schema,
        model_metadata=metadata,
    )
84+
85+
86+
@pytest.fixture
def model_builder(request):
    """Indirect fixture: resolve whichever builder fixture the test parametrizes."""
    fixture_name = request.param
    return request.getfixturevalue(fixture_name)
89+
90+
91+
@pytest.mark.skipif(
    PYTHON_VERSION_IS_NOT_310,
    reason="Testing feature needs latest metadata",
)
@pytest.mark.parametrize("model_builder", ["model_builder_model_schema_builder"], indirect=True)
def test_tei_sagemaker_endpoint(sagemaker_session, model_builder, model_input):
    """End-to-end: build a TEI-backed model, deploy it to a real SageMaker
    endpoint on ml.g5.2xlarge, run one prediction, and always clean up.

    Any exception raised during deploy/predict is captured (not raised in
    place) so the finally-block can delete the endpoint before the test is
    failed with the captured exception.
    """
    logger.info("Running in SAGEMAKER_ENDPOINT mode...")
    caught_ex = None

    iam_client = sagemaker_session.boto_session.client("iam")
    role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]

    model = model_builder.build(
        mode=Mode.SAGEMAKER_ENDPOINT, role_arn=role_arn, sagemaker_session=sagemaker_session
    )

    # Bound the whole deploy+predict cycle so a hung endpoint creation cannot
    # stall the CI run indefinitely.
    with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
        try:
            logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
            predictor = model.deploy(instance_type="ml.g5.2xlarge", initial_instance_count=1)
            predictor.predict(model_input)
            assert predictor is not None
        except Exception as e:
            # Defer failure so cleanup always runs.
            caught_ex = e
        finally:
            cleanup_model_resources(
                sagemaker_session=model_builder.sagemaker_session,
                model_name=model.name,
                endpoint_name=model.endpoint_name,
            )
            if caught_ex:
                logger.exception(caught_ex)
                assert False, f"{caught_ex} was thrown when running tei sagemaker endpoint test"

0 commit comments

Comments
 (0)