Skip to content

Commit a870e19

Browse files
bryannahm1 (Bryannah Hernandez) and sage-maker
authored
feat: Support for ModelBuilder In_Process Mode (1/2) (#4784)
* InferenceSpec support for HF * feat: InferenceSpec support for MMS and testing * Introduce changes for InProcess Mode * mb_inprocess updates * In_Process mode for TGI transformers, edits * Remove InfSpec from branch * changes to support in_process * changes to get pre-checks passing * pylint fix * unit test, test mb * period missing, added * suggestions and test added * pre-push fix * missing an @ * fixes to test, added stubbing * removing for fixes * variable fixes * init fix * tests for in process mode * prepush fix * minor fix --------- Co-authored-by: Bryannah Hernandez <[email protected]> Co-authored-by: sage-maker <[email protected]>
1 parent 0a40680 commit a870e19

File tree

8 files changed

+358
-10
lines changed

8 files changed

+358
-10
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from sagemaker.serve.mode.function_pointers import Mode
3737
from sagemaker.serve.mode.sagemaker_endpoint_mode import SageMakerEndpointMode
3838
from sagemaker.serve.mode.local_container_mode import LocalContainerMode
39+
from sagemaker.serve.mode.in_process_mode import InProcessMode
3940
from sagemaker.serve.detector.pickler import save_pkl, save_xgboost
4041
from sagemaker.serve.builder.serve_settings import _ServeSettings
4142
from sagemaker.serve.builder.djl_builder import DJL
@@ -410,7 +411,7 @@ def _prepare_for_mode(
410411
)
411412
self.env_vars.update(env_vars_sagemaker)
412413
return self.s3_upload_path, env_vars_sagemaker
413-
if self.mode == Mode.LOCAL_CONTAINER:
414+
elif self.mode == Mode.LOCAL_CONTAINER:
414415
# init the LocalContainerMode object
415416
self.modes[str(Mode.LOCAL_CONTAINER)] = LocalContainerMode(
416417
inference_spec=self.inference_spec,
@@ -422,9 +423,22 @@ def _prepare_for_mode(
422423
)
423424
self.modes[str(Mode.LOCAL_CONTAINER)].prepare()
424425
return None
426+
elif self.mode == Mode.IN_PROCESS:
427+
# init the InProcessMode object
428+
self.modes[str(Mode.IN_PROCESS)] = InProcessMode(
429+
inference_spec=self.inference_spec,
430+
schema_builder=self.schema_builder,
431+
session=self.sagemaker_session,
432+
model_path=self.model_path,
433+
env_vars=self.env_vars,
434+
model_server=self.model_server,
435+
)
436+
self.modes[str(Mode.IN_PROCESS)].prepare()
437+
return None
425438

426439
raise ValueError(
427-
"Please specify mode in: %s, %s" % (Mode.LOCAL_CONTAINER, Mode.SAGEMAKER_ENDPOINT)
440+
"Please specify mode in: %s, %s, %s"
441+
% (Mode.LOCAL_CONTAINER, Mode.SAGEMAKER_ENDPOINT, Mode.IN_PROCESS)
428442
)
429443

430444
def _get_client_translators(self):
@@ -606,6 +620,9 @@ def _overwrite_mode_in_deploy(self, overwrite_mode: str):
606620
elif overwrite_mode == Mode.LOCAL_CONTAINER:
607621
self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
608622
self._prepare_for_mode()
623+
elif overwrite_mode == Mode.IN_PROCESS:
624+
self.mode = self.pysdk_model.mode = Mode.IN_PROCESS
625+
self._prepare_for_mode()
609626
else:
610627
raise ValueError("Mode %s is not supported!" % overwrite_mode)
611628

@@ -795,9 +812,10 @@ def _initialize_for_mlflow(self, artifact_path: str) -> None:
795812
self.dependencies.update({"requirements": mlflow_model_dependency_path})
796813

797814
# Model Builder is a class to build the model for deployment.
798-
# It supports two modes of deployment
815+
# It supports three modes of deployment
799816
# 1/ SageMaker Endpoint
800817
# 2/ Local launch with container
818+
# 3/ In process mode with Transformers server in beta release
801819
def build( # pylint: disable=R0911
802820
self,
803821
mode: Type[Mode] = None,
@@ -895,8 +913,10 @@ def build( # pylint: disable=R0911
895913

896914
def _build_validations(self):
897915
"""Validations needed for model server overrides, or auto-detection or fallback"""
898-
if self.mode == Mode.IN_PROCESS:
899-
raise ValueError("IN_PROCESS mode is not supported yet!")
916+
if self.mode == Mode.IN_PROCESS and self.model_server is not ModelServer.MMS:
917+
raise ValueError(
918+
"IN_PROCESS mode is only supported for MMS/Transformers server in beta release."
919+
)
900920

901921
if self.inference_spec and self.model:
902922
raise ValueError("Can only set one of the following: model, inference_spec.")

src/sagemaker/serve/builder/transformers_builder.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,10 @@
3636
)
3737
from sagemaker.serve.detector.pickler import save_pkl
3838
from sagemaker.serve.utils.optimize_utils import _is_optimized
39-
from sagemaker.serve.utils.predictors import TransformersLocalModePredictor
39+
from sagemaker.serve.utils.predictors import (
40+
TransformersLocalModePredictor,
41+
TransformersInProcessModePredictor,
42+
)
4043
from sagemaker.serve.utils.types import ModelServer
4144
from sagemaker.serve.mode.function_pointers import Mode
4245
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
@@ -47,6 +50,7 @@
4750

4851
logger = logging.getLogger(__name__)
4952
DEFAULT_TIMEOUT = 1800
53+
LOCAL_MODES = [Mode.LOCAL_CONTAINER, Mode.IN_PROCESS]
5054

5155

5256
"""Retrieves images for different libraries - Pytorch, TensorFlow from HuggingFace hub
@@ -228,6 +232,18 @@ def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[Pr
228232
)
229233
return predictor
230234

235+
if self.mode == Mode.IN_PROCESS:
236+
timeout = kwargs.get("model_data_download_timeout")
237+
238+
predictor = TransformersInProcessModePredictor(
239+
self.modes[str(Mode.IN_PROCESS)], serializer, deserializer
240+
)
241+
242+
self.modes[str(Mode.IN_PROCESS)].create_server(
243+
predictor,
244+
)
245+
return predictor
246+
231247
self._set_instance(kwargs)
232248

233249
if "mode" in kwargs:
@@ -293,7 +309,7 @@ def _build_transformers_env(self):
293309

294310
self.pysdk_model = self._create_transformers_model()
295311

296-
if self.mode == Mode.LOCAL_CONTAINER:
312+
if self.mode in LOCAL_MODES:
297313
self._prepare_for_mode()
298314

299315
return self.pysdk_model
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
"""Module that defines the InProcessMode class"""
2+
3+
from __future__ import absolute_import
4+
from pathlib import Path
5+
import logging
6+
from typing import Dict, Type
7+
import time
8+
from datetime import datetime, timedelta
9+
10+
from sagemaker.base_predictor import PredictorBase
11+
from sagemaker.serve.spec.inference_spec import InferenceSpec
12+
from sagemaker.serve.builder.schema_builder import SchemaBuilder
13+
from sagemaker.serve.utils.types import ModelServer
14+
from sagemaker.serve.utils.exceptions import LocalDeepPingException
15+
from sagemaker.serve.model_server.multi_model_server.server import InProcessMultiModelServer
16+
from sagemaker.session import Session
17+
18+
logger = logging.getLogger(__name__)
19+
20+
_PING_HEALTH_CHECK_FAIL_MSG = (
21+
"Ping health check did not pass. "
22+
+ "Please increase container_timeout_seconds or review your inference code."
23+
)
24+
25+
26+
class InProcessMode(
    InProcessMultiModelServer,
):
    """A class that holds methods to deploy a model in-process (no container).

    Wraps an ``InProcessMultiModelServer`` and drives its lifecycle:
    ``prepare()`` → ``create_server()`` → ping health check.
    """

    def __init__(
        self,
        model_server: ModelServer,
        inference_spec: Type[InferenceSpec],
        schema_builder: Type[SchemaBuilder],
        session: Session,
        model_path: str = None,
        env_vars: Dict = None,
    ):
        # pylint: disable=bad-super-call
        super().__init__()

        self.inference_spec = inference_spec
        self.model_path = model_path
        self.env_vars = env_vars
        self.session = session
        self.schema_builder = schema_builder
        self.model_server = model_server
        # Set to the server-specific deep-ping callable in create_server();
        # remains None for model servers without in-process support.
        self._ping_container = None

    def load(self, model_path: str = None):
        """Load the model through the user-provided InferenceSpec.

        Args:
            model_path: Optional override for ``self.model_path``.

        Returns:
            Whatever ``self.inference_spec.load`` returns for the resolved path.

        Raises:
            ValueError: If the resolved path does not exist or is not a directory.
        """
        path = Path(model_path if model_path else self.model_path)
        if not path.exists():
            raise ValueError("model_path does not exist")
        if not path.is_dir():
            raise ValueError("model_path is not a valid directory")

        return self.inference_spec.load(str(path))

    def prepare(self):
        """Prepares the server"""

    def create_server(
        self,
        predictor: PredictorBase,
    ):
        """Creating the server and checking ping health.

        Raises:
            LocalDeepPingException: If no successful ping occurs before the
                time limit, or no ping callable exists for this model server.
        """
        logger.info("Waiting for model server %s to start up...", self.model_server)

        if self.model_server == ModelServer.MMS:
            self._ping_container = self._multi_model_server_deep_ping

        # Initialize up front so we raise a clear ping failure — rather than an
        # UnboundLocalError — when the loop exits without a successful ping or
        # never runs because no ping callable was selected above.
        healthy = False
        response = None

        time_limit = datetime.now() + timedelta(seconds=5)
        while self._ping_container is not None:
            if datetime.now() > time_limit:
                break

            # NOTE(review): sleeping 10s against a 5s budget allows at most one
            # ping attempt per create_server() call — confirm intended timings.
            time.sleep(10)

            healthy, response = self._ping_container(predictor)
            if healthy:
                logger.debug("Ping health check has passed. Returned %s", str(response))
                break

        if not healthy:
            # NOTE(review): InProcessDeepPingException was introduced alongside
            # this mode and may be the intended exception type here.
            raise LocalDeepPingException(_PING_HEALTH_CHECK_FAIL_MSG)

src/sagemaker/serve/model_server/multi_model_server/server.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,23 @@
2020
logger = logging.getLogger(__name__)
2121

2222

23+
class InProcessMultiModelServer:
    """In Process Mode Multi Model server instance (beta stub)."""

    def _start_serving(self):
        """Initializes the start of the server.

        Raises:
            NotImplementedError: Always — not implemented in this beta release.
        """
        # Fix: raise instead of returning an Exception object, which callers
        # could silently ignore or accidentally treat as a valid result.
        raise NotImplementedError("In-process serving startup is not implemented yet")

    def _invoke_multi_model_server_serving(self, request: object, content_type: str, accept: str):
        """Invokes the MMS server by sending a POST request.

        Raises:
            NotImplementedError: Always — not implemented in this beta release.
        """
        raise NotImplementedError("In-process MMS invocation is not implemented yet")

    def _multi_model_server_deep_ping(self, predictor: "PredictorBase"):
        """Sends a deep ping to ensure prediction.

        Returns:
            Tuple of (healthy, response); beta stub always reports healthy
            with no response payload.
        """
        response = None
        return (True, response)
38+
39+
2340
class LocalMultiModelServer:
2441
"""Local Multi Model server instance"""
2542

src/sagemaker/serve/utils/exceptions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@ def __init__(self, message):
2424
super().__init__(message=message)
2525

2626

27+
class InProcessDeepPingException(ModelBuilderException):
    """Raise when in process model serving does not pass the deep ping check"""

    # Format string rendered by the ModelBuilderException base class.
    fmt = "Error Message: {message}"
    # Builder-level error code; mirrors the sibling *DeepPingException classes.
    model_builder_error_code = 1

    def __init__(self, message):
        super().__init__(message=message)
35+
36+
2737
class LocalModelOutOfMemoryException(ModelBuilderException):
2838
"""Raise when local model serving fails to load the model"""
2939

src/sagemaker/serve/utils/predictors.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from sagemaker import Session
88
from sagemaker.serve.mode.local_container_mode import LocalContainerMode
9+
from sagemaker.serve.mode.in_process_mode import InProcessMode
910
from sagemaker.serve.builder.schema_builder import SchemaBuilder
1011
from sagemaker.serializers import IdentitySerializer, JSONSerializer
1112
from sagemaker.deserializers import BytesDeserializer, JSONDeserializer
@@ -209,6 +210,49 @@ def delete_predictor(self):
209210
self._mode_obj.destroy_server()
210211

211212

213+
class TransformersInProcessModePredictor(PredictorBase):
    """Lightweight Transformers predictor for IN_PROCESS mode deployment."""

    def __init__(
        self,
        mode_obj: Type[InProcessMode],
        serializer=JSONSerializer(),
        deserializer=JSONDeserializer(),
    ):
        # mode_obj owns the in-process server and the transport used by predict().
        self._mode_obj = mode_obj
        self.serializer = serializer
        self.deserializer = deserializer

    def predict(self, data):
        """Serialize ``data``, invoke the in-process MMS server, and deserialize
        the response into a single-element list."""
        return [
            self.deserializer.deserialize(
                io.BytesIO(
                    self._mode_obj._invoke_multi_model_server_serving(
                        self.serializer.serialize(data),
                        self.content_type,
                        self.deserializer.ACCEPT[0],
                    )
                ),
                self.content_type,
            )
        ]

    @property
    def content_type(self):
        """The MIME type of the data sent to the inference endpoint."""
        return self.serializer.CONTENT_TYPE

    @property
    def accept(self):
        """The content type(s) that are expected from the inference endpoint."""
        return self.deserializer.ACCEPT

    def delete_predictor(self):
        """Shut down the in-process model server created in IN_PROCESS mode."""
        # NOTE(review): InProcessMode does not currently define destroy_server();
        # confirm cleanup is implemented before relying on this call.
        self._mode_obj.destroy_server()
254+
255+
212256
class TeiLocalModePredictor(PredictorBase):
213257
"""Lightweight Tei predictor for local deployment in IN_PROCESS and LOCAL_CONTAINER modes"""
214258

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,11 @@
6666

6767
class TestModelBuilder(unittest.TestCase):
6868
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
69-
def test_validation_in_progress_mode_not_supported(self, mock_serveSettings):
70-
builder = ModelBuilder()
69+
def test_validation_in_process_mode_supported(self, mock_serveSettings):
70+
builder = ModelBuilder(model_server=ModelServer.TORCHSERVE)
7171
self.assertRaisesRegex(
7272
Exception,
73-
"IN_PROCESS mode is not supported yet!",
73+
"IN_PROCESS mode is only supported for MMS/Transformers server in beta release.",
7474
builder.build,
7575
Mode.IN_PROCESS,
7676
mock_role_arn,

0 commit comments

Comments
 (0)