diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py
index 013f2bc79b..1431965317 100644
--- a/src/sagemaker/serve/builder/model_builder.py
+++ b/src/sagemaker/serve/builder/model_builder.py
@@ -36,6 +36,7 @@
 from sagemaker.serve.mode.function_pointers import Mode
 from sagemaker.serve.mode.sagemaker_endpoint_mode import SageMakerEndpointMode
 from sagemaker.serve.mode.local_container_mode import LocalContainerMode
+from sagemaker.serve.mode.in_process_mode import InProcessMode
 from sagemaker.serve.detector.pickler import save_pkl, save_xgboost
 from sagemaker.serve.builder.serve_settings import _ServeSettings
 from sagemaker.serve.builder.djl_builder import DJL
@@ -410,7 +411,7 @@ def _prepare_for_mode(
             )
             self.env_vars.update(env_vars_sagemaker)
             return self.s3_upload_path, env_vars_sagemaker
-        if self.mode == Mode.LOCAL_CONTAINER:
+        elif self.mode == Mode.LOCAL_CONTAINER:
             # init the LocalContainerMode object
             self.modes[str(Mode.LOCAL_CONTAINER)] = LocalContainerMode(
                 inference_spec=self.inference_spec,
@@ -422,9 +423,22 @@
             )
             self.modes[str(Mode.LOCAL_CONTAINER)].prepare()
             return None
+        elif self.mode == Mode.IN_PROCESS:
+            # init the InProcessMode object
+            self.modes[str(Mode.IN_PROCESS)] = InProcessMode(
+                inference_spec=self.inference_spec,
+                schema_builder=self.schema_builder,
+                session=self.sagemaker_session,
+                model_path=self.model_path,
+                env_vars=self.env_vars,
+                model_server=self.model_server,
+            )
+            self.modes[str(Mode.IN_PROCESS)].prepare()
+            return None
         raise ValueError(
-            "Please specify mode in: %s, %s" % (Mode.LOCAL_CONTAINER, Mode.SAGEMAKER_ENDPOINT)
+            "Please specify mode in: %s, %s, %s"
+            % (Mode.LOCAL_CONTAINER, Mode.SAGEMAKER_ENDPOINT, Mode.IN_PROCESS)
         )

     def _get_client_translators(self):
@@ -606,6 +620,9 @@ def _overwrite_mode_in_deploy(self, overwrite_mode: str):
         elif overwrite_mode == Mode.LOCAL_CONTAINER:
             self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
             self._prepare_for_mode()
+        elif overwrite_mode == Mode.IN_PROCESS:
+            self.mode = self.pysdk_model.mode = Mode.IN_PROCESS
+            self._prepare_for_mode()
         else:
             raise ValueError("Mode %s is not supported!" % overwrite_mode)
@@ -795,9 +812,10 @@ def _initialize_for_mlflow(self, artifact_path: str) -> None:
         self.dependencies.update({"requirements": mlflow_model_dependency_path})

     # Model Builder is a class to build the model for deployment.
-    # It supports two modes of deployment
+    # It supports three modes of deployment
     # 1/ SageMaker Endpoint
     # 2/ Local launch with container
+    # 3/ In process mode with Transformers server in beta release
     def build(  # pylint: disable=R0911
         self,
         mode: Type[Mode] = None,
@@ -895,8 +913,10 @@
     def _build_validations(self):
         """Validations needed for model server overrides, or auto-detection or fallback"""
-        if self.mode == Mode.IN_PROCESS:
-            raise ValueError("IN_PROCESS mode is not supported yet!")
+        if self.mode == Mode.IN_PROCESS and self.model_server is not ModelServer.MMS:
+            raise ValueError(
+                "IN_PROCESS mode is only supported for MMS/Transformers server in beta release."
+            )

         if self.inference_spec and self.model:
             raise ValueError("Can only set one of the following: model, inference_spec.")
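With these model_builder.py changes, IN_PROCESS becomes a third dispatch target alongside SAGEMAKER_ENDPOINT and LOCAL_CONTAINER. A minimal sketch of exercising the new branch; the model id, schema samples, and the mocked session are illustrative assumptions, not taken from this diff:

```python
from unittest.mock import Mock

from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.utils.types import ModelServer

# Sample input/output pairs drive SchemaBuilder's (de)serialization detection.
schema = SchemaBuilder(
    sample_input={"inputs": "Hello, I'm a language model,", "parameters": {}},
    sample_output=[{"generated_text": "Hello, I'm a language model, and..."}],
)

builder = ModelBuilder(
    model="gpt2",                  # assumed Hugging Face hub id for the Transformers path
    schema_builder=schema,
    mode=Mode.IN_PROCESS,          # routed through the new branch in _prepare_for_mode
    model_server=ModelServer.MMS,  # anything else is rejected by _build_validations
)
model = builder.build(sagemaker_session=Mock())  # stand-in session; no AWS calls intended
```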
diff --git a/src/sagemaker/serve/builder/transformers_builder.py b/src/sagemaker/serve/builder/transformers_builder.py
index e5a616ea4b..e3f1f15cf7 100644
--- a/src/sagemaker/serve/builder/transformers_builder.py
+++ b/src/sagemaker/serve/builder/transformers_builder.py
@@ -36,7 +36,10 @@
 )
 from sagemaker.serve.detector.pickler import save_pkl
 from sagemaker.serve.utils.optimize_utils import _is_optimized
-from sagemaker.serve.utils.predictors import TransformersLocalModePredictor
+from sagemaker.serve.utils.predictors import (
+    TransformersLocalModePredictor,
+    TransformersInProcessModePredictor,
+)
 from sagemaker.serve.utils.types import ModelServer
 from sagemaker.serve.mode.function_pointers import Mode
 from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
@@ -47,6 +50,7 @@
 logger = logging.getLogger(__name__)

 DEFAULT_TIMEOUT = 1800
+LOCAL_MODES = [Mode.LOCAL_CONTAINER, Mode.IN_PROCESS]

 """Retrieves images for different libraries - Pytorch, TensorFlow from HuggingFace hub
@@ -228,6 +232,18 @@ def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[Pr
             )
             return predictor

+        if self.mode == Mode.IN_PROCESS:
+            timeout = kwargs.get("model_data_download_timeout")
+
+            predictor = TransformersInProcessModePredictor(
+                self.modes[str(Mode.IN_PROCESS)], serializer, deserializer
+            )
+
+            self.modes[str(Mode.IN_PROCESS)].create_server(
+                predictor,
+            )
+            return predictor
+
         self._set_instance(kwargs)

         if "mode" in kwargs:
@@ -293,7 +309,7 @@
         self.pysdk_model = self._create_transformers_model()

-        if self.mode == Mode.LOCAL_CONTAINER:
+        if self.mode in LOCAL_MODES:
             self._prepare_for_mode()

         return self.pysdk_model
diff --git a/src/sagemaker/serve/mode/in_process_mode.py b/src/sagemaker/serve/mode/in_process_mode.py
new file mode 100644
index 0000000000..dc3b4fd74f
--- /dev/null
+++ b/src/sagemaker/serve/mode/in_process_mode.py
@@ -0,0 +1,92 @@
+"""Module that defines the InProcessMode class"""
+
+from __future__ import absolute_import
+from pathlib import Path
+import logging
+from typing import Dict, Type
+import time
+from datetime import datetime, timedelta
+
+from sagemaker.base_predictor import PredictorBase
+from sagemaker.serve.spec.inference_spec import InferenceSpec
+from sagemaker.serve.builder.schema_builder import SchemaBuilder
+from sagemaker.serve.utils.types import ModelServer
+from sagemaker.serve.utils.exceptions import LocalDeepPingException
+from sagemaker.serve.model_server.multi_model_server.server import InProcessMultiModelServer
+from sagemaker.session import Session
+
+logger = logging.getLogger(__name__)
+
+_PING_HEALTH_CHECK_FAIL_MSG = (
+    "Ping health check did not pass. "
+    + "Please increase container_timeout_seconds or review your inference code."
+)
+
+
+class InProcessMode(
+    InProcessMultiModelServer,
+):
+    """A class that holds methods to deploy a model in an in-process environment"""
+
+    def __init__(
+        self,
+        model_server: ModelServer,
+        inference_spec: Type[InferenceSpec],
+        schema_builder: Type[SchemaBuilder],
+        session: Session,
+        model_path: str = None,
+        env_vars: Dict = None,
+    ):
+        # pylint: disable=bad-super-call
+        super().__init__()
+
+        self.inference_spec = inference_spec
+        self.model_path = model_path
+        self.env_vars = env_vars
+        self.session = session
+        self.schema_builder = schema_builder
+        self.model_server = model_server
+        self._ping_container = None
+
+    def load(self, model_path: str = None):
+        """Loads model path, checks that path exists"""
+        path = Path(model_path if model_path else self.model_path)
+        if not path.exists():
+            raise ValueError("model_path does not exist")
+        if not path.is_dir():
+            raise ValueError("model_path is not a valid directory")
+
+        return self.inference_spec.load(str(path))
+
+    def prepare(self):
+        """Prepares the server"""
+
+    def create_server(
+        self,
+        predictor: PredictorBase,
+    ):
+        """Creates the server and checks ping health."""
+        logger.info("Waiting for model server %s to start up...", self.model_server)
+
+        if self.model_server == ModelServer.MMS:
+            self._ping_container = self._multi_model_server_deep_ping
+
+        # Initialize so the final check below cannot hit an unbound local when
+        # no ping callable was registered or the loop exits on the time limit.
+        healthy = False
+        time_limit = datetime.now() + timedelta(seconds=5)
+        while self._ping_container is not None:
+            final_pull = datetime.now() > time_limit
+
+            if final_pull:
+                break
+
+            time.sleep(10)
+
+            healthy, response = self._ping_container(predictor)
+            if healthy:
+                logger.debug("Ping health check has passed. Returned %s", str(response))
+                break
+
+        if not healthy:
+            raise LocalDeepPingException(_PING_HEALTH_CHECK_FAIL_MSG)
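`InProcessMode.load` only validates the directory and then defers to the user-supplied `InferenceSpec`, so the actual model loading stays user-defined. A sketch of a spec it could call into; the pipeline task and class name are assumptions:

```python
from sagemaker.serve.spec.inference_spec import InferenceSpec


class TextGenSpec(InferenceSpec):
    """Illustrative spec: InProcessMode.load passes the validated model_path to load()."""

    def load(self, model_dir: str):
        # Assumption: a transformers pipeline; any object your invoke() can call works.
        from transformers import pipeline

        return pipeline("text-generation", model=model_dir)

    def invoke(self, input_object, model):
        return model(input_object)
```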
diff --git a/src/sagemaker/serve/model_server/multi_model_server/server.py b/src/sagemaker/serve/model_server/multi_model_server/server.py
index 8586fa85fb..b957186b99 100644
--- a/src/sagemaker/serve/model_server/multi_model_server/server.py
+++ b/src/sagemaker/serve/model_server/multi_model_server/server.py
@@ -20,6 +20,23 @@
 logger = logging.getLogger(__name__)

+class InProcessMultiModelServer:
+    """In Process Mode Multi Model server instance"""
+
+    def _start_serving(self):
+        """Initializes the start of the server"""
+        raise NotImplementedError("Not implemented")
+
+    def _invoke_multi_model_server_serving(self, request: object, content_type: str, accept: str):
+        """Invokes the MMS server by sending a POST request"""
+        raise NotImplementedError("Not implemented")
+
+    def _multi_model_server_deep_ping(self, predictor: PredictorBase):
+        """Sends a deep ping to ensure prediction"""
+        response = None
+        return (True, response)
+
+
 class LocalMultiModelServer:
     """Local Multi Model server instance"""
diff --git a/src/sagemaker/serve/utils/exceptions.py b/src/sagemaker/serve/utils/exceptions.py
index 30b22ba869..eb22e8cce2 100644
--- a/src/sagemaker/serve/utils/exceptions.py
+++ b/src/sagemaker/serve/utils/exceptions.py
@@ -24,6 +24,16 @@
         super().__init__(message=message)

+class InProcessDeepPingException(ModelBuilderException):
+    """Raise when in process model serving does not pass the deep ping check"""
+
+    fmt = "Error Message: {message}"
+    model_builder_error_code = 1
+
+    def __init__(self, message):
+        super().__init__(message=message)
+
+
 class LocalModelOutOfMemoryException(ModelBuilderException):
     """Raise when local model serving fails to load the model"""
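The serving hooks on `InProcessMultiModelServer` are stubs for now, while `_multi_model_server_deep_ping` unconditionally reports healthy, which is what lets `create_server` pass in this beta. Note that `create_server` still raises `LocalDeepPingException`; the new `InProcessDeepPingException` is defined alongside it with the same shape. A small sketch of how the new exception formats, assuming the `ModelBuilderException` base interpolates `fmt` the way its siblings do:

```python
from sagemaker.serve.utils.exceptions import InProcessDeepPingException

try:
    raise InProcessDeepPingException("Ping health check did not pass.")
except InProcessDeepPingException as err:
    print(err)                           # Error Message: Ping health check did not pass.
    print(err.model_builder_error_code)  # 1
```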
diff --git a/src/sagemaker/serve/utils/predictors.py b/src/sagemaker/serve/utils/predictors.py
index 25a995eb48..be6133e8e1 100644
--- a/src/sagemaker/serve/utils/predictors.py
+++ b/src/sagemaker/serve/utils/predictors.py
@@ -6,6 +6,7 @@

 from sagemaker import Session
 from sagemaker.serve.mode.local_container_mode import LocalContainerMode
+from sagemaker.serve.mode.in_process_mode import InProcessMode
 from sagemaker.serve.builder.schema_builder import SchemaBuilder
 from sagemaker.serializers import IdentitySerializer, JSONSerializer
 from sagemaker.deserializers import BytesDeserializer, JSONDeserializer
@@ -209,6 +210,49 @@
         self._mode_obj.destroy_server()

+class TransformersInProcessModePredictor(PredictorBase):
+    """Lightweight Transformers predictor for in-process deployment"""
+
+    def __init__(
+        self,
+        mode_obj: Type[InProcessMode],
+        serializer=JSONSerializer(),
+        deserializer=JSONDeserializer(),
+    ):
+        self._mode_obj = mode_obj
+        self.serializer = serializer
+        self.deserializer = deserializer
+
+    def predict(self, data):
+        """Serializes the input, invokes the in-process MMS, and deserializes the reply"""
+        return [
+            self.deserializer.deserialize(
+                io.BytesIO(
+                    self._mode_obj._invoke_multi_model_server_serving(
+                        self.serializer.serialize(data),
+                        self.content_type,
+                        self.deserializer.ACCEPT[0],
+                    )
+                ),
+                self.content_type,
+            )
+        ]
+
+    @property
+    def content_type(self):
+        """The MIME type of the data sent to the inference endpoint."""
+        return self.serializer.CONTENT_TYPE
+
+    @property
+    def accept(self):
+        """The content type(s) that are expected from the inference endpoint."""
+        return self.deserializer.ACCEPT
+
+    def delete_predictor(self):
+        """Shut down the server that you created in IN_PROCESS mode"""
+        self._mode_obj.destroy_server()
+
+
 class TeiLocalModePredictor(PredictorBase):
     """Lightweight Tei predictor for local deployment in IN_PROCESS and LOCAL_CONTAINER modes"""
diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py
index 4818b9d8b6..2752e991ff 100644
--- a/tests/unit/sagemaker/serve/builder/test_model_builder.py
+++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py
@@ -66,11 +66,11 @@ class TestModelBuilder(unittest.TestCase):
     @patch("sagemaker.serve.builder.model_builder._ServeSettings")
-    def test_validation_in_progress_mode_not_supported(self, mock_serveSettings):
-        builder = ModelBuilder()
+    def test_validation_in_process_mode_unsupported_server(self, mock_serveSettings):
+        builder = ModelBuilder(model_server=ModelServer.TORCHSERVE)
         self.assertRaisesRegex(
             Exception,
-            "IN_PROCESS mode is not supported yet!",
+            "IN_PROCESS mode is only supported for MMS/Transformers server in beta release.",
             builder.build,
             Mode.IN_PROCESS,
             mock_role_arn,
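The renamed unit test pins the validation path end to end; the same behavior, sketched outside the test harness (the role ARN and session are stand-ins mirroring the test's mocks):

```python
from unittest.mock import Mock

from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.utils.types import ModelServer

builder = ModelBuilder(model_server=ModelServer.TORCHSERVE)
try:
    builder.build(Mode.IN_PROCESS, "arn:aws:iam::123456789012:role/Example", Mock())
except ValueError as err:
    assert "IN_PROCESS mode is only supported" in str(err)
```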
diff --git a/tests/unit/sagemaker/serve/mode/test_in_process_mode.py b/tests/unit/sagemaker/serve/mode/test_in_process_mode.py
new file mode 100644
index 0000000000..f5890982b9
--- /dev/null
+++ b/tests/unit/sagemaker/serve/mode/test_in_process_mode.py
@@ -0,0 +1,152 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import unittest
+from unittest.mock import patch, Mock
+
+from sagemaker.serve.mode.in_process_mode import InProcessMode
+from sagemaker.serve import SchemaBuilder
+from sagemaker.serve.utils.types import ModelServer
+from sagemaker.serve.utils.exceptions import LocalDeepPingException
+
+
+mock_prompt = "Hello, I'm a language model,"
+mock_response = "Hello, I'm a language model, and I'm here to help you with your English."
+mock_sample_input = {"inputs": mock_prompt, "parameters": {}}
+mock_sample_output = [{"generated_text": mock_response}]
+
+
+class TestInProcessMode(unittest.TestCase):
+
+    @patch("sagemaker.serve.mode.in_process_mode.Path")
+    @patch("sagemaker.serve.spec.inference_spec.InferenceSpec")
+    @patch("sagemaker.session.Session")
+    def test_load_happy(self, mock_session, mock_inference_spec, mock_path):
+        mock_path.return_value.exists.side_effect = lambda *args, **kwargs: True
+        mock_path.return_value.is_dir.side_effect = lambda *args, **kwargs: True
+
+        mock_inference_spec.load.side_effect = lambda *args, **kwargs: "Dummy load"
+
+        mock_schema_builder = SchemaBuilder(mock_sample_input, mock_sample_output)
+        in_process_mode = InProcessMode(
+            model_server=ModelServer.MMS,
+            inference_spec=mock_inference_spec,
+            schema_builder=mock_schema_builder,
+            session=mock_session,
+            model_path="model_path",
+            env_vars={"key": "val"},
+        )
+
+        res = in_process_mode.load(model_path="/tmp/model-builder/code/")
+
+        self.assertEqual(res, "Dummy load")
+        self.assertEqual(in_process_mode.inference_spec, mock_inference_spec)
+        self.assertEqual(in_process_mode.schema_builder, mock_schema_builder)
+        self.assertEqual(in_process_mode.model_path, "model_path")
+        self.assertEqual(in_process_mode.env_vars, {"key": "val"})
+
+    @patch("sagemaker.serve.mode.in_process_mode.Path")
+    @patch("sagemaker.serve.spec.inference_spec.InferenceSpec")
+    @patch("sagemaker.session.Session")
+    def test_load_ex(self, mock_session, mock_inference_spec, mock_path):
+        mock_path.return_value.exists.side_effect = lambda *args, **kwargs: False
+        mock_path.return_value.is_dir.side_effect = lambda *args, **kwargs: True
+
+        mock_inference_spec.load.side_effect = lambda *args, **kwargs: "Dummy load"
+
+        mock_schema_builder = SchemaBuilder(mock_sample_input, mock_sample_output)
+        in_process_mode = InProcessMode(
+            model_server=ModelServer.MMS,
+            inference_spec=mock_inference_spec,
+            schema_builder=mock_schema_builder,
+            session=mock_session,
+            model_path="model_path",
+        )
+
+        self.assertRaises(ValueError, in_process_mode.load, "/tmp/model-builder/code/")
+
+        mock_path.return_value.exists.side_effect = lambda *args, **kwargs: True
+        mock_path.return_value.is_dir.side_effect = lambda *args, **kwargs: False
+
+        mock_inference_spec.load.side_effect = lambda *args, **kwargs: "Dummy load"
+        mock_schema_builder = SchemaBuilder(mock_sample_input, mock_sample_output)
+        in_process_mode = InProcessMode(
+            model_server=ModelServer.MMS,
+            inference_spec=mock_inference_spec,
+            schema_builder=mock_schema_builder,
+            session=mock_session,
+            model_path="model_path",
+        )
+
+        self.assertRaises(ValueError, in_process_mode.load, "/tmp/model-builder/code/")
+
+    @patch("sagemaker.serve.mode.in_process_mode.logger")
+    @patch("sagemaker.base_predictor.PredictorBase")
+    @patch("sagemaker.serve.spec.inference_spec.InferenceSpec")
+    @patch("sagemaker.session.Session")
+    def test_create_server_happy(
+        self, mock_session, mock_inference_spec, mock_predictor, mock_logger
+    ):
+        mock_response = "Fake response"
+        mock_multi_model_server_deep_ping = Mock()
+        mock_multi_model_server_deep_ping.side_effect = lambda *args, **kwargs: (
+            True,
+            mock_response,
+        )
+
+        in_process_mode = InProcessMode(
+            model_server=ModelServer.MMS,
+            inference_spec=mock_inference_spec,
+            schema_builder=SchemaBuilder(mock_sample_input, mock_sample_output),
+            session=mock_session,
+            model_path="model_path",
+        )
+
+        in_process_mode._multi_model_server_deep_ping = mock_multi_model_server_deep_ping
+
+        in_process_mode.create_server(predictor=mock_predictor)
+
+        mock_logger.info.assert_called_once_with(
+            "Waiting for model server %s to start up...", ModelServer.MMS
+        )
+        mock_logger.debug.assert_called_once_with(
+            "Ping health check has passed. Returned %s", str(mock_response)
+        )
+
+    @patch("sagemaker.base_predictor.PredictorBase")
+    @patch("sagemaker.serve.spec.inference_spec.InferenceSpec")
+    @patch("sagemaker.session.Session")
+    def test_create_server_ex(
+        self,
+        mock_session,
+        mock_inference_spec,
+        mock_predictor,
+    ):
+        mock_multi_model_server_deep_ping = Mock()
+        mock_multi_model_server_deep_ping.side_effect = lambda *args, **kwargs: (
+            False,
+            None,
+        )
+
+        in_process_mode = InProcessMode(
+            model_server=ModelServer.MMS,
+            inference_spec=mock_inference_spec,
+            schema_builder=SchemaBuilder(mock_sample_input, mock_sample_output),
+            session=mock_session,
+            model_path="model_path",
+        )
+
+        in_process_mode._multi_model_server_deep_ping = mock_multi_model_server_deep_ping
+
+        self.assertRaises(LocalDeepPingException, in_process_mode.create_server, mock_predictor)
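For reference, the predictor's wire format can be exercised without a live server: `JSONSerializer` produces the request body that `_invoke_multi_model_server_serving` would POST, and `JSONDeserializer` parses the reply. The fake reply below is an assumption, since the invoke hook is still a stub in this diff:

```python
import io

from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

serializer, deserializer = JSONSerializer(), JSONDeserializer()

# Request body as TransformersInProcessModePredictor.predict would build it.
body = serializer.serialize({"inputs": "Hello, I'm a language model,", "parameters": {}})

# Assumed MMS-style reply; the real shape depends on the eventual server implementation.
fake_reply = b'[{"generated_text": "Hello, I\'m a language model, and ..."}]'
print(deserializer.deserialize(io.BytesIO(fake_reply), deserializer.ACCEPT[0]))
```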