feat: Support for ModelBuilder In_Process Mode (1/2) #4784
Changes from 13 commits
Diff 1 of 4 — Transformers builder:

```diff
@@ -35,7 +35,10 @@
 )
 from sagemaker.serve.detector.pickler import save_pkl
 from sagemaker.serve.utils.optimize_utils import _is_optimized
-from sagemaker.serve.utils.predictors import TransformersLocalModePredictor
+from sagemaker.serve.utils.predictors import (
+    TransformersLocalModePredictor,
+    TransformersInProcessModePredictor,
+)
 from sagemaker.serve.utils.types import ModelServer
 from sagemaker.serve.mode.function_pointers import Mode
 from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
```
```diff
@@ -161,7 +164,7 @@ def _get_hf_metadata_create_model(self) -> Type[Model]:
                 vpc_config=self.vpc_config,
             )

-        if not self.image_uri and self.mode == Mode.LOCAL_CONTAINER:
+        if self.mode == Mode.LOCAL_CONTAINER or self.mode == Mode.IN_PROCESS:
             self.image_uri = pysdk_model.serving_image_uri(
                 self.sagemaker_session.boto_region_name, "local"
             )
```

Review comment: nit:
Reply: Good suggestion, I've made the edit, thank you.
```diff
@@ -227,6 +230,22 @@ def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[Pr
             )
             return predictor

+        if self.mode == Mode.IN_PROCESS:
+            timeout = kwargs.get("model_data_download_timeout")
+
+            predictor = TransformersInProcessModePredictor(
+                self.modes[str(Mode.IN_PROCESS)], serializer, deserializer
+            )
+
+            self.modes[str(Mode.IN_PROCESS)].create_server(
+                self.image_uri,
+                timeout if timeout else DEFAULT_TIMEOUT,
+                None,
+                predictor,
+                self.pysdk_model.env,
+            )
+            return predictor
+
         if "mode" in kwargs:
             del kwargs["mode"]
         if "role" in kwargs:
```
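For orientation, here is a minimal sketch of how this new branch would be exercised from user code. The model id and the schema samples are illustrative assumptions, not part of the diff:

```python
# Hypothetical usage of the new IN_PROCESS mode (model id and samples are assumptions).
from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.mode.function_pointers import Mode

model_builder = ModelBuilder(
    model="bert-base-uncased",  # assumption: a Hugging Face model served via the Transformers path
    schema_builder=SchemaBuilder(
        sample_input="What is in-process serving?",
        sample_output=[{"generated_text": "..."}],
    ),
    mode=Mode.IN_PROCESS,
)
model = model_builder.build()
# deploy() routes through _transformers_model_builder_deploy_wrapper above, creating a
# TransformersInProcessModePredictor and starting the in-process model server.
predictor = model.deploy()
```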
```diff
@@ -274,7 +293,7 @@ def _build_transformers_env(self):

         self.pysdk_model = self._create_transformers_model()

-        if self.mode == Mode.LOCAL_CONTAINER:
+        if self.mode == Mode.LOCAL_CONTAINER or self.mode == Mode.IN_PROCESS:
             self._prepare_for_mode()

         return self.pysdk_model
```

Review comment: nit:
Reply: Made the change.
Diff 2 of 4 — new file (+106 lines) defining the InProcessMode class:

```python
"""Module that defines the InProcessMode class"""

from __future__ import absolute_import
from pathlib import Path
import logging
from typing import Dict, Type
import time

from sagemaker.base_predictor import PredictorBase
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.utils.exceptions import LocalDeepPingException
from sagemaker.serve.model_server.multi_model_server.server import InProcessMultiModelServer
from sagemaker.session import Session

logger = logging.getLogger(__name__)

_PING_HEALTH_CHECK_INTERVAL_SEC = 5

_PING_HEALTH_CHECK_FAIL_MSG = (
    "Container did not pass the ping health check. "
```

Review comment: Does IN_PROCESS mode use container_timeout_seconds?
Reply: No it does not, I will be sure to change this, good catch.

```python
    + "Please increase container_timeout_seconds or review your inference code."
)


class InProcessMode(
    InProcessMultiModelServer,
):
    """A class that holds methods to deploy model to a container in process environment"""

    def __init__(
        self,
        model_server: ModelServer,
        inference_spec: Type[InferenceSpec],
        schema_builder: Type[SchemaBuilder],
        session: Session,
        model_path: str = None,
        env_vars: Dict = None,
    ):
        # pylint: disable=bad-super-call
        super().__init__()

        self.inference_spec = inference_spec
        self.model_path = model_path
        self.env_vars = env_vars
        self.session = session
        self.schema_builder = schema_builder
        self.ecr = session.boto_session.client("ecr")
```

Review comment: Can remove self.ecr and container-specific things here.
Reply: Removed.

```python
        self.model_server = model_server
        self.client = None
        self.container = None
        self.secret_key = None
        self._ping_container = None
        self._invoke_serving = None

    def load(self, model_path: str = None):
        """Placeholder docstring"""
        path = Path(model_path if model_path else self.model_path)
        if not path.exists():
            raise Exception("model_path does not exist")
        if not path.is_dir():
            raise Exception("model_path is not a valid directory")

        return self.inference_spec.load(str(path))

    def prepare(self):
        """Placeholder docstring"""

    def create_server(
        self,
        image: str,
        secret_key: str,
        predictor: PredictorBase,
        env_vars: Dict[str, str] = None,
        model_path: str = None,
    ):
        """Placeholder docstring"""

        # self._pull_image(image=image)
```

Review comment: Can remove this comment.
Reply: Removed.

```python
        # self.destroy_server()

        logger.info("Waiting for model server %s to start up...", self.model_server)

        if self.model_server == ModelServer.MMS:
            self._start_serving(
                client=self.client,
                image=image,
                model_path=model_path if model_path else self.model_path,
                secret_key=secret_key,
                env_vars=env_vars if env_vars else self.env_vars,
            )
            logger.info("Starting PING")
```

Review comment: Remove this log line too.
Reply: Removed :)

```python
            self._ping_container = self._multi_model_server_deep_ping

        while True:
            time.sleep(10)

            healthy, response = self._ping_container(predictor)
            if healthy:
                logger.debug("Ping health check has passed. Returned %s", str(response))
                break

        if not healthy:
            raise LocalDeepPingException(_PING_HEALTH_CHECK_FAIL_MSG)
```
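One observation on the loop above: `while True` only exits once the ping succeeds, so the trailing `if not healthy:` branch can never fire, and `_PING_HEALTH_CHECK_INTERVAL_SEC` is defined but unused. A bounded variant might look like the following sketch; the `timeout_seconds` parameter and the helper itself are assumptions, not code from this PR:

```python
# Sketch only: a bounded health-check loop so a server that never comes up
# eventually raises instead of blocking forever. Names like logger and the
# _PING_* constants refer to the module above.
import time

def _wait_until_healthy(ping_fn, predictor, timeout_seconds=60):
    deadline = time.time() + timeout_seconds
    while time.time() < deadline:
        healthy, response = ping_fn(predictor)
        if healthy:
            logger.debug("Ping health check has passed. Returned %s", str(response))
            return response
        time.sleep(_PING_HEALTH_CHECK_INTERVAL_SEC)
    raise LocalDeepPingException(_PING_HEALTH_CHECK_FAIL_MSG)
```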
Diff 3 of 4 — multi-model server module, adding an InProcessMultiModelServer class:

```diff
@@ -20,6 +20,83 @@
 logger = logging.getLogger(__name__)


+class InProcessMultiModelServer:
+    """In Process Mode Multi Model server instance"""
+
+    def _start_serving(
+        self,
+        client: object,
+        image: str,
+        model_path: str,
+        secret_key: str,
+        env_vars: dict,
+    ):
+        """Placeholder docstring"""
+        env = {
+            "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
+            "SAGEMAKER_PROGRAM": "inference.py",
+            "SAGEMAKER_SERVE_SECRET_KEY": secret_key,
+            "LOCAL_PYTHON": platform.python_version(),
+        }
+        if env_vars:
+            env_vars.update(env)
+        else:
+            env_vars = env
+
+        self.container = client.containers.run(
```

Review comment: Are we spinning up a docker container or using FastAPI for serving?
Reply: The container will be stubbed; this is only 1/2 of the full implementation of InProcess mode. My next PR will include the FastAPI. (A hedged sketch of such a server follows this diff.)

```diff
+            image,
+            "serve",
+            network_mode="host",
+            detach=True,
+            auto_remove=True,
+            volumes={
+                Path(model_path).joinpath("code"): {
+                    "bind": MODE_DIR_BINDING,
+                    "mode": "rw",
+                },
+            },
+            environment=env_vars,
+        )
+
+    def _invoke_multi_model_server_serving(self, request: object, content_type: str, accept: str):
```

Review comment: I would leave these methods as stubs... return an Exception("Not implemented").
Reply: I have stubbed it, thank you.

```diff
+        """Placeholder docstring"""
```

Review comment: Update docstrings.
Reply: The docstrings are updated, good catch.

```diff
+        logger.info(content_type)
+        logger.info(accept)
+
+        try:
+            response = requests.post(
+                "http://0.0.0.0:8080/invocations",
+                data=request,
+                headers={"Content-Type": content_type, "Accept": accept},
+                timeout=600,
+            )
+            response.raise_for_status()
+
+            logger.info(response.content)
+
+            return response.content
+        except Exception as e:
+            raise Exception("Unable to send request to the local container server") from e
+
+        return (True, response)
+
+    def _multi_model_server_deep_ping(self, predictor: PredictorBase):
```

Review comment: Is this complete?
Reply: I have stubbed it.

```diff
+        """Placeholder docstring"""
+        response = None
+        logger.debug("AM I HERE? PING PING")
```

Review comment: Remove this line.
Reply: Removed it, thank you.

```diff
+        # try:
+        #     response = predictor.predict(self.schema_builder.sample_input)
+        #     return True, response
+        # # pylint: disable=broad-except
+        # except Exception as e:
+        #     if "422 Client Error: Unprocessable Entity for url" in str(e):
+        #         raise LocalModelInvocationException(str(e))
+        #     return False, response
+
+        return (True, response)
+
+
 class LocalMultiModelServer:
     """Local Multi Model server instance"""
```
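The review thread above asks whether serving spins up a Docker container or uses FastAPI; the author says the FastAPI server lands in PR 2/2. As a rough sketch of what such an in-process server could look like — everything here, including the handler bodies, is an assumption about the follow-up, not code from this PR:

```python
# Hypothetical in-process FastAPI server for the planned PR 2/2.
from fastapi import FastAPI, Request
import uvicorn

app = FastAPI()

@app.get("/ping")
def ping():
    # Would let a deep-ping-style health check succeed once the app is up.
    return {"status": "healthy"}

@app.post("/invocations")
async def invocations(request: Request):
    payload = await request.body()
    # The real implementation would run the loaded model/InferenceSpec here.
    return {"echo": payload.decode("utf-8")}

if __name__ == "__main__":
    # Matches the host/port that _invoke_multi_model_server_serving already targets.
    uvicorn.run(app, host="0.0.0.0", port=8080)
```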
Diff 4 of 4 — exceptions module:

```diff
@@ -1,4 +1,4 @@
-"""Placeholder Docstring"""
+"""Exceptions used across different model builder invocations"""

 from __future__ import absolute_import
```

Review comment: Nice, thanks for the update.
Reply: Of course, thanks for the suggestion.

```diff
@@ -24,6 +24,16 @@ def __init__(self, message):
         super().__init__(message=message)


+class InProcessDeepPingException(ModelBuilderException):
+    """Raise when in process model serving does not pass the deep ping check"""
+
+    fmt = "Error Message: {message}"
+    model_builder_error_code = 1
+
+    def __init__(self, message):
+        super().__init__(message=message)
+
+
 class LocalModelOutOfMemoryException(ModelBuilderException):
     """Raise when local model serving fails to load the model"""
```
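A small illustration of how the new exception type could surface to callers. This wiring is an assumption rather than part of the diff — InProcessMode above still raises LocalDeepPingException — and `model` refers to the hypothetical IN_PROCESS build from the earlier sketch:

```python
# Hypothetical: catching the new exception around an IN_PROCESS deploy.
from sagemaker.serve.utils.exceptions import InProcessDeepPingException

try:
    predictor = model.deploy()  # model built with mode=Mode.IN_PROCESS
except InProcessDeepPingException as e:
    print(f"In-process server failed its deep ping: {e}")
```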
Review comment: nit: IN_PROCESS mode is only supported for MMS/Transformers server in beta release.
Reply: This is better wording, I will change it, thank you!
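For completeness, the guard this nit suggests might read roughly as follows; the placement and the exact condition are assumptions, only the message wording comes from the thread:

```python
# Sketch of the suggested validation; assumed to live where ModelBuilder
# checks mode/model-server combinations.
if self.mode == Mode.IN_PROCESS and self.model_server != ModelServer.MMS:
    raise ValueError(
        "IN_PROCESS mode is only supported for MMS/Transformers server in beta release."
    )
```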