
Commit 0847a16

pravali96 authored and pintaoz-aws committed
Add in_process mode support for DJL and TorchServe servers (#1570)
* add in-process mode for DJL server
* fix format
* add inference_spec as a member of DJL
* add the validations for model server
* fix typo
* fix test assertion
* add unit-testing
* have a common server for inprocess mode
* fix failing tests
* add support to torchserve
* fix tests to include torchserve servers
* use custom inference_spec code instead of HF pipelines
* fix tests for app.py
* fix unit test failure
* fix format
* use schema_builder for serialization and deserialization
* remove task field
* remove unused import
1 parent fff8cdd commit 0847a16
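
Taken together, these changes let `ModelBuilder` stand up a DJL or TorchServe model entirely inside the current Python process, driven by a user-supplied `InferenceSpec`. A rough end-to-end sketch of the intended usage (the spec body, model id, and sample payloads are illustrative, not part of this commit; `transformers` is assumed to be installed):

```python
from transformers import pipeline

from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.utils.types import ModelServer


class TextGenSpec(InferenceSpec):
    """Hypothetical spec: the in-process server calls load() once, then invoke() per request."""

    def get_model(self):
        return "gpt2"  # illustrative HF model id

    def load(self, model_dir: str):
        # The in-process server passes model_dir=None, so load from the Hub instead.
        return pipeline("text-generation", model=self.get_model(), device="cpu")

    def invoke(self, input_object: object, model: object):
        prompt = input_object["inputs"] if isinstance(input_object, dict) else input_object
        return model(prompt, max_length=30, num_return_sequences=1, truncation=True)


model_builder = ModelBuilder(
    inference_spec=TextGenSpec(),
    schema_builder=SchemaBuilder(
        sample_input={"inputs": "Hello", "parameters": {}},
        sample_output=[{"generated_text": "Hello there"}],
    ),
    mode=Mode.IN_PROCESS,
    model_server=ModelServer.DJL_SERVING,
)
model = model_builder.build()
predictor = model.deploy()  # an InProcessModePredictor; no container is started
print(predictor.predict({"inputs": "Hello", "parameters": {}}))
```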

File tree

12 files changed, +257 −134 lines changed


src/sagemaker/serve/builder/djl_builder.py

+28 −8
@@ -47,14 +47,15 @@
 from sagemaker.serve.model_server.djl_serving.prepare import (
     _create_dir_structure,
 )
-from sagemaker.serve.utils.predictors import DjlLocalModePredictor
+from sagemaker.serve.utils.predictors import InProcessModePredictor, DjlLocalModePredictor
 from sagemaker.serve.utils.types import ModelServer
 from sagemaker.serve.mode.function_pointers import Mode
 from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
 from sagemaker.djl_inference.model import DJLModel
 from sagemaker.base_predictor import PredictorBase
 
 logger = logging.getLogger(__name__)
+LOCAL_MODES = [Mode.LOCAL_CONTAINER, Mode.IN_PROCESS]
 
 # Match JumpStart DJL entrypoint format
 _CODE_FOLDER = "code"
@@ -77,6 +78,7 @@ def __init__(self):
         self.mode = None
         self.model_server = None
         self.image_uri = None
+        self.inference_spec = None
         self._is_custom_image_uri = False
         self.image_config = None
         self.vpc_config = None
@@ -96,11 +98,11 @@ def __init__(self):
 
     @abstractmethod
     def _prepare_for_mode(self):
-        """Placeholder docstring"""
+        """Abstract method"""
 
     @abstractmethod
     def _get_client_translators(self):
-        """Placeholder docstring"""
+        """Abstract method"""
 
     def _is_djl(self):
         """Placeholder docstring"""
@@ -146,7 +148,7 @@ def _create_djl_model(self) -> Type[Model]:
 
     @_capture_telemetry("djl.deploy")
     def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
-        """Placeholder docstring"""
+        """Returns predictor depending on local mode or endpoint mode"""
        timeout = kwargs.get("model_data_download_timeout")
         if timeout:
             self.env_vars.update({"MODEL_LOADING_TIMEOUT": str(timeout)})
@@ -189,6 +191,18 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
 
         serializer = self.schema_builder.input_serializer
         deserializer = self.schema_builder._output_deserializer
+
+        if self.mode == Mode.IN_PROCESS:
+
+            predictor = InProcessModePredictor(
+                self.modes[str(Mode.IN_PROCESS)], serializer, deserializer
+            )
+
+            self.modes[str(Mode.IN_PROCESS)].create_server(
+                predictor,
+            )
+            return predictor
+
         if self.mode == Mode.LOCAL_CONTAINER:
             timeout = kwargs.get("model_data_download_timeout")
 
@@ -249,9 +263,15 @@ def _build_for_hf_djl(self):
 
         _create_dir_structure(self.model_path)
         if not hasattr(self, "pysdk_model"):
-            self.env_vars.update({"HF_MODEL_ID": self.model})
+            if self.inference_spec is not None:
+                self.env_vars.update({"HF_MODEL_ID": self.inference_spec.get_model()})
+            else:
+                self.env_vars.update({"HF_MODEL_ID": self.model})
+
+        logger.info(self.env_vars)
+
         self.hf_model_config = _get_model_config_properties_from_hf(
-            self.model, self.env_vars.get("HF_TOKEN")
+            self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HF_TOKEN")
         )
         default_djl_configurations, _default_max_new_tokens = _get_default_djl_configurations(
             self.model, self.hf_model_config, self.schema_builder
@@ -260,9 +280,10 @@ def _build_for_hf_djl(self):
         self.schema_builder.sample_input["parameters"][
             "max_new_tokens"
         ] = _default_max_new_tokens
+
         self.pysdk_model = self._create_djl_model()
 
-        if self.mode == Mode.LOCAL_CONTAINER:
+        if self.mode in LOCAL_MODES:
             self._prepare_for_mode()
 
         return self.pysdk_model
@@ -451,7 +472,6 @@ def _build_for_djl(self):
         """Placeholder docstring"""
         self._validate_djl_serving_sample_data()
         self.secret_key = None
-
         self.pysdk_model = self._build_for_hf_djl()
         self.pysdk_model.tune = self._tune_for_hf_djl
         if self.role_arn:
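
The net effect of the `_build_for_hf_djl` edits above is a two-step resolution for `HF_MODEL_ID`, which then drives the HF config lookup instead of `self.model`. A condensed, hypothetical restatement of that precedence (the helper name is not from the commit):

```python
def resolve_hf_model_id(inference_spec, model) -> str:
    """Mirror of the precedence _build_for_hf_djl now applies."""
    if inference_spec is not None:
        return inference_spec.get_model()  # a spec-supplied id wins
    return model  # otherwise fall back to the plain model id
```

`_get_model_config_properties_from_hf` is then keyed off the resolved id, so a spec-supplied model is honored end to end.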

src/sagemaker/serve/builder/model_builder.py

+21 −4
@@ -81,7 +81,7 @@
     _extract_speculative_draft_model_provider,
     _jumpstart_speculative_decoding,
 )
-from sagemaker.serve.utils.predictors import _get_local_mode_predictor
+from sagemaker.serve.utils.predictors import _get_local_mode_predictor, InProcessModePredictor
 from sagemaker.serve.utils.hardware_detector import (
     _get_gpu_info,
     _get_gpu_info_fallback,
@@ -566,6 +566,18 @@ def _model_builder_deploy_wrapper(
         if mode and mode != self.mode:
             self._overwrite_mode_in_deploy(overwrite_mode=mode)
 
+        if self.mode == Mode.IN_PROCESS:
+            serializer, deserializer = self._get_client_translators()
+
+            predictor = InProcessModePredictor(
+                self.modes[str(Mode.IN_PROCESS)], serializer, deserializer
+            )
+
+            self.modes[str(Mode.IN_PROCESS)].create_server(
+                predictor,
+            )
+            return predictor
+
         if self.mode == Mode.LOCAL_CONTAINER:
             serializer, deserializer = self._get_client_translators()
             predictor = _get_local_mode_predictor(
@@ -919,11 +931,16 @@ def build(  # pylint: disable=R0911
 
     def _build_validations(self):
         """Validations needed for model server overrides, or auto-detection or fallback"""
-        if self.mode == Mode.IN_PROCESS and self.model_server is not ModelServer.MMS:
+        if (
+            self.mode == Mode.IN_PROCESS
+            and self.model_server is not ModelServer.MMS
+            and self.model_server is not ModelServer.DJL_SERVING
+            and self.model_server is not ModelServer.TORCHSERVE
+        ):
             raise ValueError(
-                "IN_PROCESS mode is only supported for MMS/Transformers server in beta release."
+                "IN_PROCESS mode is only supported for the following servers "
+                "in beta release: MMS/Transformers, TORCHSERVE, DJL_SERVING server"
             )
-
         if self.inference_spec and self.model:
             raise ValueError("Can only set one of the following: model, inference_spec.")
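
A design note on the widened `_build_validations` check: the chained `is not` comparisons form an allow-list, which could equivalently be expressed as a membership test. A hypothetical standalone restatement (the commit itself keeps the three comparisons inline):

```python
from sagemaker.serve.utils.types import ModelServer

# Hypothetical helper; names are illustrative.
_IN_PROCESS_SERVERS = {ModelServer.MMS, ModelServer.DJL_SERVING, ModelServer.TORCHSERVE}


def supports_in_process(model_server: ModelServer) -> bool:
    """True when the chosen server is allowed with Mode.IN_PROCESS in the beta."""
    return model_server in _IN_PROCESS_SERVERS
```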

src/sagemaker/serve/builder/transformers_builder.py

+2 −2
@@ -38,7 +38,7 @@
 from sagemaker.serve.utils.optimize_utils import _is_optimized
 from sagemaker.serve.utils.predictors import (
     TransformersLocalModePredictor,
-    TransformersInProcessModePredictor,
+    InProcessModePredictor,
 )
 from sagemaker.serve.utils.types import ModelServer
 from sagemaker.serve.mode.function_pointers import Mode
@@ -237,7 +237,7 @@ def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[Pr
         if self.mode == Mode.IN_PROCESS:
             timeout = kwargs.get("model_data_download_timeout")
 
-            predictor = TransformersInProcessModePredictor(
+            predictor = InProcessModePredictor(
                 self.modes[str(Mode.IN_PROCESS)], serializer, deserializer
             )

src/sagemaker/serve/mode/in_process_mode.py

+10 −9
@@ -13,20 +13,15 @@
 from sagemaker.serve.builder.schema_builder import SchemaBuilder
 from sagemaker.serve.utils.types import ModelServer
 from sagemaker.serve.utils.exceptions import InProcessDeepPingException
-from sagemaker.serve.model_server.multi_model_server.server import InProcessMultiModelServer
+from sagemaker.serve.model_server.in_process_model_server.in_process_server import InProcessServing
 from sagemaker.session import Session
 
 logger = logging.getLogger(__name__)
 
-_PING_HEALTH_CHECK_FAIL_MSG = (
-    "Ping health check did not pass. "
-    + "Please increase container_timeout_seconds or review your inference code."
-)
+_PING_HEALTH_CHECK_FAIL_MSG = "Ping health check did not pass. Please review your inference code."
 
 
-class InProcessMode(
-    InProcessMultiModelServer,
-):
+class InProcessMode(InProcessServing):
     """A class that holds methods to deploy model to a container in process environment"""
 
     def __init__(
@@ -70,7 +65,13 @@ def create_server(
         logger.info("Waiting for model server %s to start up...", self.model_server)
 
         if self.model_server == ModelServer.MMS:
-            self._ping_local_server = self._multi_model_server_deep_ping
+            self._ping_local_server = self._deep_ping
+            self._start_serving()
+        elif self.model_server == ModelServer.DJL_SERVING:
+            self._ping_local_server = self._deep_ping
+            self._start_serving()
+        elif self.model_server == ModelServer.TORCHSERVE:
+            self._ping_local_server = self._deep_ping
             self._start_serving()
 
         # allow some time for server to be ready.
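
All three branches of `create_server` now assign the same `_deep_ping` and start the same server, so an equivalent (hypothetical) formulation would be a single membership check rather than the `if`/`elif` chain the commit ships:

```python
if self.model_server in (ModelServer.MMS, ModelServer.DJL_SERVING, ModelServer.TORCHSERVE):
    self._ping_local_server = self._deep_ping
    self._start_serving()
```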

src/sagemaker/serve/app.py renamed to src/sagemaker/serve/model_server/in_process_model_server/app.py

+25 −23
@@ -3,10 +3,13 @@
 from __future__ import absolute_import
 
 import asyncio
+import io
 import logging
 import threading
 from typing import Optional
 
+from sagemaker.serve.spec.inference_spec import InferenceSpec
+from sagemaker.serve.builder.schema_builder import SchemaBuilder
 
 logger = logging.getLogger(__name__)
 
@@ -17,45 +20,44 @@
     logger.error("Unable to import uvicorn, check if uvicorn is installed.")
 
 
-try:
-    from transformers import pipeline
-except ImportError:
-    logger.error("Unable to import transformers, check if transformers is installed.")
-
-
 try:
     from fastapi import FastAPI, Request, APIRouter
 except ImportError:
     logger.error("Unable to import fastapi, check if fastapi is installed.")
 
 
 class InProcessServer:
-    """Placeholder docstring"""
+    """Generic In-Process Server for Serving Models using InferenceSpec"""
 
-    def __init__(self, model_id: Optional[str] = None, task: Optional[str] = None):
+    def __init__(
+        self,
+        inference_spec: Optional[InferenceSpec] = None,
+        schema_builder: Optional[SchemaBuilder] = None,
+    ):
         self._thread = None
         self._loop = None
         self._stop_event = asyncio.Event()
         self._router = APIRouter()
-        self._model_id = model_id
-        self._task = task
         self.server = None
         self.port = None
         self.host = None
-        # TODO: Pick up device automatically.
-        self._generator = pipeline(task, model=model_id, device="cpu")
-
-        # pylint: disable=unused-variable
-        @self._router.post("/generate")
-        async def generate_text(prompt: Request):
-            """Placeholder docstring"""
-            str_prompt = await prompt.json()
-            str_prompt = str_prompt["inputs"] if "inputs" in str_prompt else str_prompt
-
-            generated_text = self._generator(
-                str_prompt, max_length=30, num_return_sequences=1, truncation=True
+        self.inference_spec = inference_spec
+        self.schema_builder = schema_builder
+        self._load_model = self.inference_spec.load(model_dir=None)
+
+        @self._router.post("/invoke")
+        async def invoke(request: Request):
+            """Generate text based on the provided prompt"""
+
+            request_header = request.headers
+            request_body = await request.body()
+            content_type = request_header.get("Content-Type", None)
+            input_data = schema_builder.input_deserializer.deserialize(
+                io.BytesIO(request_body), content_type[0]
             )
-            return generated_text
+            logger.debug(f"Received request: {input_data}")
+            response = self.inference_spec.invoke(input_data, self._load_model)
+            return response
 
         self._create_server()

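The rewritten server delegates all model specifics to the `InferenceSpec`: the `/invoke` route deserializes the request body with the schema builder's deserializer, then hands the result to `invoke()` together with the object returned by `load()`. A minimal sketch of that same data flow outside FastAPI (the spec and payload are illustrative):

```python
import io

from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.spec.inference_spec import InferenceSpec


class EchoSpec(InferenceSpec):
    """Hypothetical spec used only to trace the /invoke data flow."""

    def load(self, model_dir: str):
        return None  # nothing to load for an echo model

    def invoke(self, input_object: object, model: object):
        return {"echo": input_object}


spec = EchoSpec()
schema_builder = SchemaBuilder(sample_input="Hello", sample_output={"echo": "Hello"})
loaded = spec.load(model_dir=None)                      # done once, in __init__
body = b'"Hello"'                                       # raw request body
input_data = schema_builder.input_deserializer.deserialize(
    io.BytesIO(body), "application/json"                # content type from the header
)
print(spec.invoke(input_data, loaded))                  # done on every POST /invoke
```
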
src/sagemaker/serve/model_server/in_process_model_server/in_process_server.py

+60 −0

@@ -0,0 +1,60 @@
+"""Module for In_process Serving"""
+
+from __future__ import absolute_import
+
+import requests
+import logging
+from sagemaker.serve.utils.exceptions import LocalModelInvocationException
+from sagemaker.base_predictor import PredictorBase
+
+logger = logging.getLogger(__name__)
+
+
+class InProcessServing:
+    """In Process Mode server instance"""
+
+    def _start_serving(self):
+        """Initializes the start of the server"""
+        from sagemaker.serve.model_server.in_process_model_server.app import InProcessServer
+
+        self.server = InProcessServer(
+            inference_spec=self.inference_spec, schema_builder=self.schema_builder
+        )
+        self.server.start_server()
+
+    def _stop_serving(self):
+        """Stops the server"""
+        self.server.stop_server()
+
+    def _invoke_serving(self, request: object, content_type: str, accept: str):
+        """Placeholder docstring"""
+        try:
+            response = requests.post(
+                f"http://{self.server.host}:{self.server.port}/invoke",
+                data=request,
+                headers={"Content-Type": content_type, "Accept": accept},
+                timeout=600,
+            )
+            response.raise_for_status()
+
+            return response.content
+        except Exception as e:
+            if "Connection refused" in str(e):
+                raise Exception(
+                    "Unable to send request to the local server: Connection refused."
+                ) from e
+            raise Exception("Unable to send request to the local container server %s", str(e))
+
+    def _deep_ping(self, predictor: PredictorBase):
+        """Sends a deep ping to ensure prediction"""
+        healthy = False
+        response = None
+        try:
+            response = predictor.predict(self.schema_builder.sample_input)
+            healthy = response is not None
+        # pylint: disable=broad-except
+        except Exception as e:
+            if "422 Client Error: Unprocessable Entity for url" in str(e):
+                raise LocalModelInvocationException(str(e))
+
+        return healthy, response
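
`_invoke_serving` is a plain HTTP POST against the single `/invoke` route, so once `create_server` has started the app, the server can also be exercised directly. A hedged sketch (host, port, and headers are illustrative; the real values come from the running `InProcessServer` and the schema builder's serializer):

```python
import requests

# Hypothetical direct call; InProcessModePredictor.predict() serializes the
# payload and _invoke_serving sends it in exactly this shape.
response = requests.post(
    "http://127.0.0.1:9007/invoke",
    data=b'{"inputs": "Hello", "parameters": {}}',
    headers={"Content-Type": "application/json", "Accept": "application/json"},
    timeout=600,
)
response.raise_for_status()
print(response.content)
```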
