Skip to content

feat: Support for ModelBuilder In_Process Mode (1/2) #4784

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 27 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions src/sagemaker/serve/builder/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.mode.sagemaker_endpoint_mode import SageMakerEndpointMode
from sagemaker.serve.mode.local_container_mode import LocalContainerMode
from sagemaker.serve.mode.in_process_mode import InProcessMode
from sagemaker.serve.detector.pickler import save_pkl, save_xgboost
from sagemaker.serve.builder.serve_settings import _ServeSettings
from sagemaker.serve.builder.djl_builder import DJL
Expand Down Expand Up @@ -410,7 +411,7 @@ def _prepare_for_mode(
)
self.env_vars.update(env_vars_sagemaker)
return self.s3_upload_path, env_vars_sagemaker
if self.mode == Mode.LOCAL_CONTAINER:
elif self.mode == Mode.LOCAL_CONTAINER:
# init the LocalContainerMode object
self.modes[str(Mode.LOCAL_CONTAINER)] = LocalContainerMode(
inference_spec=self.inference_spec,
Expand All @@ -422,9 +423,22 @@ def _prepare_for_mode(
)
self.modes[str(Mode.LOCAL_CONTAINER)].prepare()
return None
elif self.mode == Mode.IN_PROCESS:
# init the InProcessMode object
self.modes[str(Mode.IN_PROCESS)] = InProcessMode(
inference_spec=self.inference_spec,
schema_builder=self.schema_builder,
session=self.sagemaker_session,
model_path=self.model_path,
env_vars=self.env_vars,
model_server=self.model_server,
)
self.modes[str(Mode.IN_PROCESS)].prepare()
return None

raise ValueError(
"Please specify mode in: %s, %s" % (Mode.LOCAL_CONTAINER, Mode.SAGEMAKER_ENDPOINT)
"Please specify mode in: %s, %s, %s"
% (Mode.LOCAL_CONTAINER, Mode.SAGEMAKER_ENDPOINT, Mode.IN_PROCESS)
)

def _get_client_translators(self):
Expand Down Expand Up @@ -606,6 +620,9 @@ def _overwrite_mode_in_deploy(self, overwrite_mode: str):
elif overwrite_mode == Mode.LOCAL_CONTAINER:
self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
self._prepare_for_mode()
elif overwrite_mode == Mode.IN_PROCESS:
self.mode = self.pysdk_model.mode = Mode.IN_PROCESS
self._prepare_for_mode()
else:
raise ValueError("Mode %s is not supported!" % overwrite_mode)

Expand Down Expand Up @@ -795,9 +812,10 @@ def _initialize_for_mlflow(self, artifact_path: str) -> None:
self.dependencies.update({"requirements": mlflow_model_dependency_path})

# Model Builder is a class to build the model for deployment.
# It supports two modes of deployment
# It supports three modes of deployment
# 1/ SageMaker Endpoint
# 2/ Local launch with container
# 3/ In process mode with Transformers server in beta release
def build( # pylint: disable=R0911
self,
mode: Type[Mode] = None,
Expand Down Expand Up @@ -895,8 +913,10 @@ def build( # pylint: disable=R0911

def _build_validations(self):
"""Validations needed for model server overrides, or auto-detection or fallback"""
if self.mode == Mode.IN_PROCESS:
raise ValueError("IN_PROCESS mode is not supported yet!")
if self.mode == Mode.IN_PROCESS and self.model_server is not ModelServer.MMS:
raise ValueError(
"IN_PROCESS mode is only supported for MMS/Transformers server in beta release."
)

if self.inference_spec and self.model:
raise ValueError("Can only set one of the following: model, inference_spec.")
Expand Down
20 changes: 18 additions & 2 deletions src/sagemaker/serve/builder/transformers_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@
)
from sagemaker.serve.detector.pickler import save_pkl
from sagemaker.serve.utils.optimize_utils import _is_optimized
from sagemaker.serve.utils.predictors import TransformersLocalModePredictor
from sagemaker.serve.utils.predictors import (
TransformersLocalModePredictor,
TransformersInProcessModePredictor,
)
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
Expand All @@ -47,6 +50,7 @@

logger = logging.getLogger(__name__)
DEFAULT_TIMEOUT = 1800
LOCAL_MODES = [Mode.LOCAL_CONTAINER, Mode.IN_PROCESS]


"""Retrieves images for different libraries - Pytorch, TensorFlow from HuggingFace hub
Expand Down Expand Up @@ -228,6 +232,18 @@ def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[Pr
)
return predictor

if self.mode == Mode.IN_PROCESS:
timeout = kwargs.get("model_data_download_timeout")

predictor = TransformersInProcessModePredictor(
self.modes[str(Mode.IN_PROCESS)], serializer, deserializer
)

self.modes[str(Mode.IN_PROCESS)].create_server(
predictor,
)
return predictor

self._set_instance(kwargs)

if "mode" in kwargs:
Expand Down Expand Up @@ -293,7 +309,7 @@ def _build_transformers_env(self):

self.pysdk_model = self._create_transformers_model()

if self.mode == Mode.LOCAL_CONTAINER:
if self.mode in LOCAL_MODES:
self._prepare_for_mode()

return self.pysdk_model
Expand Down
89 changes: 89 additions & 0 deletions src/sagemaker/serve/mode/in_process_mode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""Module that defines the InProcessMode class"""

from __future__ import absolute_import
from pathlib import Path
import logging
from typing import Dict, Type
import time
from datetime import datetime, timedelta

from sagemaker.base_predictor import PredictorBase
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.utils.exceptions import LocalDeepPingException
from sagemaker.serve.model_server.multi_model_server.server import InProcessMultiModelServer
from sagemaker.session import Session

logger = logging.getLogger(__name__)

_PING_HEALTH_CHECK_FAIL_MSG = (
"Ping health check did not pass. "
+ "Please increase container_timeout_seconds or review your inference code."
)


class InProcessMode(
    InProcessMultiModelServer,
):
    """Holds methods to deploy a model for in-process (no-container) serving."""

    def __init__(
        self,
        model_server: ModelServer,
        inference_spec: Type[InferenceSpec],
        schema_builder: Type[SchemaBuilder],
        session: Session,
        model_path: str = None,
        env_vars: Dict = None,
    ):
        # pylint: disable=bad-super-call
        super().__init__()

        self.inference_spec = inference_spec
        self.model_path = model_path
        self.env_vars = env_vars
        self.session = session
        self.schema_builder = schema_builder
        self.model_server = model_server
        # Deep-ping callable; assigned in create_server when the server supports it.
        self._ping_container = None

    def load(self, model_path: str = None):
        """Load the model through the inference spec after validating the path.

        Args:
            model_path: Optional override for ``self.model_path``.

        Returns:
            Whatever ``self.inference_spec.load`` returns for the resolved path.

        Raises:
            ValueError: If the path does not exist or is not a directory.
        """
        path = Path(model_path if model_path else self.model_path)
        if not path.exists():
            raise ValueError("model_path does not exist")
        if not path.is_dir():
            raise ValueError("model_path is not a valid directory")

        return self.inference_spec.load(str(path))

    def prepare(self):
        """Prepares the server (no-op for in-process mode)."""

    def create_server(
        self,
        predictor: PredictorBase,
    ):
        """Creating the server and checking ping health.

        Raises:
            LocalDeepPingException: If no successful deep ping was recorded,
                including when no ping callable is configured for this server.
        """
        logger.info("Waiting for model server %s to start up...", self.model_server)

        if self.model_server == ModelServer.MMS:
            self._ping_container = self._multi_model_server_deep_ping

        # Initialize before the loop so the final health check below is always
        # well-defined. Previously, when `self._ping_container` stayed None (a
        # non-MMS server) or the time budget expired before the first ping,
        # `healthy` was referenced before assignment and raised NameError.
        healthy = False
        response = None
        time_limit = datetime.now() + timedelta(seconds=5)
        while self._ping_container is not None:
            final_pull = datetime.now() > time_limit

            if final_pull:
                break

            time.sleep(10)

            healthy, response = self._ping_container(predictor)
            if healthy:
                logger.debug("Ping health check has passed. Returned %s", str(response))
                break

        if not healthy:
            raise LocalDeepPingException(_PING_HEALTH_CHECK_FAIL_MSG)
17 changes: 17 additions & 0 deletions src/sagemaker/serve/model_server/multi_model_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,23 @@
logger = logging.getLogger(__name__)


class InProcessMultiModelServer:
    """In Process Mode Multi Model server instance.

    The serving and invocation entry points are beta stubs: each hands back an
    ``Exception`` instance (it does not raise) until the in-process server is
    implemented, while the deep ping unconditionally reports healthy.
    """

    def _start_serving(self):
        """Initializes the start of the server (stub)."""
        return Exception("Not implemented")

    def _invoke_multi_model_server_serving(self, request: object, content_type: str, accept: str):
        """Invokes the MMS server by sending POST request (stub)."""
        return Exception("Not implemented")

    def _multi_model_server_deep_ping(self, predictor: PredictorBase):
        """Sends a deep ping to ensure prediction; always reports healthy for now."""
        return True, None


class LocalMultiModelServer:
"""Local Multi Model server instance"""

Expand Down
10 changes: 10 additions & 0 deletions src/sagemaker/serve/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ def __init__(self, message):
super().__init__(message=message)


class InProcessDeepPingException(ModelBuilderException):
    """Raised when in-process model serving fails the deep ping health check."""

    # Message template and telemetry error code mirror the sibling exceptions
    # in this module.
    model_builder_error_code = 1
    fmt = "Error Message: {message}"

    def __init__(self, message):
        super().__init__(message=message)


class LocalModelOutOfMemoryException(ModelBuilderException):
"""Raise when local model serving fails to load the model"""

Expand Down
44 changes: 44 additions & 0 deletions src/sagemaker/serve/utils/predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from sagemaker import Session
from sagemaker.serve.mode.local_container_mode import LocalContainerMode
from sagemaker.serve.mode.in_process_mode import InProcessMode
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serializers import IdentitySerializer, JSONSerializer
from sagemaker.deserializers import BytesDeserializer, JSONDeserializer
Expand Down Expand Up @@ -209,6 +210,49 @@ def delete_predictor(self):
self._mode_obj.destroy_server()


class TransformersInProcessModePredictor(PredictorBase):
    """Lightweight Transformers predictor for in-process deployment."""

    def __init__(
        self,
        mode_obj: Type[InProcessMode],
        serializer=JSONSerializer(),
        deserializer=JSONDeserializer(),
    ):
        self._mode_obj = mode_obj
        self.serializer = serializer
        self.deserializer = deserializer

    def predict(self, data):
        """Serialize *data*, invoke the in-process MMS server, and deserialize.

        Returns the deserialized response wrapped in a single-element list.
        """
        payload = self.serializer.serialize(data)
        raw_response = self._mode_obj._invoke_multi_model_server_serving(
            payload,
            self.content_type,
            self.deserializer.ACCEPT[0],
        )
        decoded = self.deserializer.deserialize(
            io.BytesIO(raw_response),
            self.content_type,
        )
        return [decoded]

    @property
    def content_type(self):
        """The MIME type of the data sent to the inference endpoint."""
        return self.serializer.CONTENT_TYPE

    @property
    def accept(self):
        """The content type(s) that are expected from the inference endpoint."""
        return self.deserializer.ACCEPT

    def delete_predictor(self):
        """Shut down the in-process server backing this predictor.

        NOTE(review): ``InProcessMode`` as added in this change does not define
        ``destroy_server`` — confirm this call resolves (e.g. via a mixin)
        before relying on it in IN_PROCESS mode.
        """
        self._mode_obj.destroy_server()


class TeiLocalModePredictor(PredictorBase):
"""Lightweight Tei predictor for local deployment in IN_PROCESS and LOCAL_CONTAINER modes"""

Expand Down
6 changes: 3 additions & 3 deletions tests/unit/sagemaker/serve/builder/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@

class TestModelBuilder(unittest.TestCase):
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
def test_validation_in_progress_mode_not_supported(self, mock_serveSettings):
builder = ModelBuilder()
def test_validation_in_progress_mode_supported(self, mock_serveSettings):
builder = ModelBuilder(model_server=ModelServer.TORCHSERVE)
self.assertRaisesRegex(
Exception,
"IN_PROCESS mode is not supported yet!",
"IN_PROCESS mode is only supported for MMS/Transformers server in beta release.",
builder.build,
Mode.IN_PROCESS,
mock_role_arn,
Expand Down
Loading