
Torchserve support for PyTorch Inference #58

Closed

wants to merge 1 commit
4 changes: 4 additions & 0 deletions src/sagemaker_inference/etc/default-ts.properties
@@ -0,0 +1,4 @@
# Based on: https://github.com/awslabs/mxnet-model-server/blob/master/docs/configuration.md
enable_envvars_config=true
decode_input_request=false
load_models=ALL
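
With enable_envvars_config=true, the server can also pick up configuration from TS_-prefixed environment variables, so any property here can be overridden without editing this file. A minimal sketch of such an override from Python, set before the server process is launched (the property name is illustrative):

    import os

    # Child processes inherit this; the server reads it at startup.
    os.environ["TS_DEFAULT_WORKERS_PER_MODEL"] = "2"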
50 changes: 50 additions & 0 deletions src/sagemaker_inference/etc/ts.log4j.properties
@@ -0,0 +1,50 @@
log4j.rootLogger = WARN, console

log4j.appender.console = org.apache.log4j.ConsoleAppender
log4j.appender.console.Target = System.out
log4j.appender.console.layout = org.apache.log4j.PatternLayout
log4j.appender.console.layout.ConversionPattern = %d{ISO8601} [%-5p] %t %c - %m%n

log4j.appender.access_log = org.apache.log4j.RollingFileAppender
log4j.appender.access_log.File = ${LOG_LOCATION}/access_log.log
log4j.appender.access_log.MaxFileSize = 10MB
log4j.appender.access_log.MaxBackupIndex = 5
log4j.appender.access_log.layout = org.apache.log4j.PatternLayout
log4j.appender.access_log.layout.ConversionPattern = %d{ISO8601} - %m%n

log4j.appender.ts_log = org.apache.log4j.RollingFileAppender
log4j.appender.ts_log.File = ${LOG_LOCATION}/ts_log.log
log4j.appender.ts_log.MaxFileSize = 10MB
log4j.appender.ts_log.MaxBackupIndex = 5
log4j.appender.ts_log.layout = org.apache.log4j.PatternLayout
log4j.appender.ts_log.layout.ConversionPattern = %d{ISO8601} [%-5p] %t %c - %m%n

log4j.appender.ts_metrics = org.apache.log4j.RollingFileAppender
log4j.appender.ts_metrics.File = ${METRICS_LOCATION}/ts_metrics.log
log4j.appender.ts_metrics.MaxFileSize = 10MB
log4j.appender.ts_metrics.MaxBackupIndex = 5
log4j.appender.ts_metrics.layout = org.apache.log4j.PatternLayout
log4j.appender.ts_metrics.layout.ConversionPattern = %d{ISO8601} - %m%n

log4j.appender.model_log = org.apache.log4j.RollingFileAppender
log4j.appender.model_log.File = ${LOG_LOCATION}/model_log.log
log4j.appender.model_log.MaxFileSize = 10MB
log4j.appender.model_log.MaxBackupIndex = 5
log4j.appender.model_log.layout = org.apache.log4j.PatternLayout
log4j.appender.model_log.layout.ConversionPattern = %d{ISO8601} [%-5p] %c - %m%n

log4j.appender.model_metrics = org.apache.log4j.RollingFileAppender
log4j.appender.model_metrics.File = ${METRICS_LOCATION}/model_metrics.log
log4j.appender.model_metrics.MaxFileSize = 10MB
log4j.appender.model_metrics.MaxBackupIndex = 5
log4j.appender.model_metrics.layout = org.apache.log4j.PatternLayout
log4j.appender.model_metrics.layout.ConversionPattern = %d{ISO8601} - %m%n

log4j.logger.com.amazonaws.ml.ts = INFO, ts_log
log4j.logger.ACCESS_LOG = INFO, access_log
log4j.logger.TS_METRICS = WARN, ts_metrics
log4j.logger.MODEL_METRICS = WARN, model_metrics
log4j.logger.MODEL_LOG = WARN, model_log

log4j.logger.org.apache = OFF
log4j.logger.io.netty = ERROR
64 changes: 6 additions & 58 deletions src/sagemaker_inference/model_server.py
@@ -25,6 +25,7 @@

import sagemaker_inference
from sagemaker_inference import default_handler_service, environment, logging, utils
from sagemaker_inference.model_server_utils import add_sigterm_handler, set_python_path, install_requirements, retrieve_model_server_process
from sagemaker_inference.environment import code_dir

logger = logging.get_logger()
@@ -70,14 +71,14 @@ def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):
    if ENABLE_MULTI_MODEL:
        if not os.getenv("SAGEMAKER_HANDLER"):
            os.environ["SAGEMAKER_HANDLER"] = handler_service
        _set_python_path()
        set_python_path()
    else:
        _adapt_to_mms_format(handler_service)

    _create_model_server_config_file()

    if os.path.exists(REQUIREMENTS_PATH):
        _install_requirements()
        install_requirements()

    mxnet_model_server_cmd = [
        "mxnet-model-server",
@@ -93,9 +94,9 @@ def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):
    logger.info(mxnet_model_server_cmd)
    subprocess.Popen(mxnet_model_server_cmd)

    mms_process = _retrieve_mms_server_process()
    mms_process = retrieve_model_server_process(MMS_NAMESPACE)

    _add_sigterm_handler(mms_process)
    add_sigterm_handler(mms_process)

    mms_process.wait()

@@ -121,21 +122,7 @@ def _adapt_to_mms_format(handler_service):
    logger.info(model_archiver_cmd)
    subprocess.check_call(model_archiver_cmd)

    _set_python_path()


def _set_python_path():
    # MMS handles code execution by appending the export path, provided
    # to the model archiver, to the PYTHONPATH env var.
    # The code_dir has to be added to the PYTHONPATH otherwise the
    # user provided module can not be imported properly.
    code_dir_path = "{}:".format(environment.code_dir)

    if PYTHON_PATH_ENV in os.environ:
        os.environ[PYTHON_PATH_ENV] = code_dir_path + os.environ[PYTHON_PATH_ENV]
    else:
        os.environ[PYTHON_PATH_ENV] = code_dir_path

    set_python_path()

def _create_model_server_config_file():
    configuration_properties = _generate_mms_config_properties()
@@ -166,42 +153,3 @@ def _generate_mms_config_properties():
    default_configuration = utils.read_file(DEFAULT_MMS_CONFIG_FILE)

    return default_configuration + custom_configuration


def _add_sigterm_handler(mms_process):
    def _terminate(signo, frame):  # pylint: disable=unused-argument
        try:
            os.kill(mms_process.pid, signal.SIGTERM)
        except OSError:
            pass

    signal.signal(signal.SIGTERM, _terminate)


def _install_requirements():
    logger.info("installing packages from requirements.txt...")
    pip_install_cmd = [sys.executable, "-m", "pip", "install", "-r", REQUIREMENTS_PATH]

    try:
        subprocess.check_call(pip_install_cmd)
    except subprocess.CalledProcessError:
        logger.error("failed to install required packages, exiting")
        raise ValueError("failed to install required packages")


# retry for 10 seconds
@retry(stop_max_delay=10 * 1000)
def _retrieve_mms_server_process():
    mms_server_processes = list()

    for process in psutil.process_iter():
        if MMS_NAMESPACE in process.cmdline():
            mms_server_processes.append(process)

    if not mms_server_processes:
        raise Exception("mms model server was unsuccessfully started")

    if len(mms_server_processes) > 1:
        raise Exception("multiple mms model servers are not supported")

    return mms_server_processes[0]
66 changes: 66 additions & 0 deletions src/sagemaker_inference/model_server_utils.py
@@ -0,0 +1,66 @@
"""Shared utilities for configuring and supervising a model server process."""
import os
import signal
import subprocess
import sys

import psutil
from retrying import retry

from sagemaker_inference import environment, logging
from sagemaker_inference.environment import code_dir

PYTHON_PATH_ENV = "PYTHONPATH"
REQUIREMENTS_PATH = os.path.join(code_dir, "requirements.txt")

logger = logging.get_logger()


def add_sigterm_handler(mms_process):
    """Forward a SIGTERM received by this process to the model server process."""
    def _terminate(signo, frame):  # pylint: disable=unused-argument
        try:
            os.kill(mms_process.pid, signal.SIGTERM)
        except OSError:
            pass

    signal.signal(signal.SIGTERM, _terminate)
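
For context, a minimal sketch of how this helper is meant to be wired up, assuming a long-running child process started with subprocess.Popen (the sleep command is only a stand-in for the real server command):

    import subprocess
    from sagemaker_inference.model_server_utils import add_sigterm_handler

    child = subprocess.Popen(["sleep", "3600"])  # stand-in for the server process
    add_sigterm_handler(child)  # a SIGTERM to this process now also stops the child
    child.wait()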


def set_python_path():
    """Prepend code_dir to the PYTHONPATH environment variable.

    MMS handles code execution by appending the export path, provided
    to the model archiver, to the PYTHONPATH env var. The code_dir has
    to be added to the PYTHONPATH, otherwise the user-provided module
    cannot be imported properly.
    """
    code_dir_path = "{}:".format(environment.code_dir)

    if PYTHON_PATH_ENV in os.environ:
        os.environ[PYTHON_PATH_ENV] = code_dir_path + os.environ[PYTHON_PATH_ENV]
    else:
        os.environ[PYTHON_PATH_ENV] = code_dir_path
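
As a quick illustration (assuming environment.code_dir is /opt/ml/model/code and PYTHONPATH was previously /usr/local/lib/python):

    # before: PYTHONPATH=/usr/local/lib/python
    # after set_python_path(): PYTHONPATH=/opt/ml/model/code:/usr/local/lib/python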


def install_requirements():
    """Install packages from the user-supplied requirements.txt via pip."""
    logger.info("installing packages from requirements.txt...")
    pip_install_cmd = [sys.executable, "-m", "pip", "install", "-r", REQUIREMENTS_PATH]

    try:
        subprocess.check_call(pip_install_cmd)
    except subprocess.CalledProcessError:
        logger.error("failed to install required packages, exiting")
        raise ValueError("failed to install required packages")


# Retry for up to 10 seconds while the server process comes up.
@retry(stop_max_delay=10 * 1000)
def retrieve_model_server_process(namespace):
    """Return the single running server process whose cmdline contains `namespace`."""
    model_server_processes = list()

    for process in psutil.process_iter():
        if namespace in process.cmdline():
            model_server_processes.append(process)

    if not model_server_processes:
        raise Exception("model server failed to start")

    if len(model_server_processes) > 1:
        raise Exception("multiple model servers are not supported")

    return model_server_processes[0]
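
The @retry(stop_max_delay=10 * 1000) decorator from the retrying package re-invokes the function on any exception until 10 seconds have elapsed, which covers the window between Popen returning and the server appearing in the process table. A minimal sketch of the same pattern, using a hypothetical wait_for_file helper:

    import os
    from retrying import retry

    @retry(stop_max_delay=10 * 1000)  # retry on any exception for up to 10 seconds
    def wait_for_file(path):
        # hypothetical example: raising keeps the retry loop going
        if not os.path.exists(path):
            raise Exception("{} not present yet".format(path))
        return path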
163 changes: 163 additions & 0 deletions src/sagemaker_inference/torchserve.py
@@ -0,0 +1,163 @@
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.

Review comment:
This package is designed to be framework agnostic. Torchserve belongs either in pytorch-serving or in its own repository.

Contributor (author) reply:
Closing this PR as this has been moved to aws/sagemaker-pytorch-inference-toolkit#79

#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains functionality to configure and start the
multi-model server."""
from __future__ import absolute_import

import os
import subprocess

import pkg_resources

import sagemaker_inference
from sagemaker_inference import environment, logging, utils
from sagemaker_inference.environment import code_dir
from sagemaker_inference.model_server_utils import (
    add_sigterm_handler,
    install_requirements,
    retrieve_model_server_process,
    set_python_path,
)

logger = logging.get_logger()

TS_CONFIG_FILE = os.path.join("/etc", "sagemaker-ts.properties")
DEFAULT_TS_CONFIG_FILE = pkg_resources.resource_filename(
    sagemaker_inference.__name__, "/etc/default-ts.properties"
)
MME_TS_CONFIG_FILE = pkg_resources.resource_filename(
    sagemaker_inference.__name__, "/etc/mme-ts.properties"
)
DEFAULT_TS_LOG_FILE = pkg_resources.resource_filename(
    sagemaker_inference.__name__, "/etc/ts.log4j.properties"
)
DEFAULT_TS_MODEL_DIRECTORY = os.path.join(os.getcwd(), ".sagemaker/ts/models")
DEFAULT_TS_MODEL_NAME = "model"
DEFAULT_TS_MODEL_SERIALIZED_FILE = "model.pth"
DEFAULT_TS_HANDLER_SERVICE = "sagemaker_pytorch_serving_container.handler_service"

ENABLE_MULTI_MODEL = os.getenv("SAGEMAKER_MULTI_MODEL", "false") == "true"
MODEL_STORE = "/" if ENABLE_MULTI_MODEL else DEFAULT_TS_MODEL_DIRECTORY

PYTHON_PATH_ENV = "PYTHONPATH"
REQUIREMENTS_PATH = os.path.join(code_dir, "requirements.txt")
TS_NAMESPACE = "org.pytorch.serve.ModelServer"


def start_model_server(handler_service=DEFAULT_TS_HANDLER_SERVICE):
    """Configure and start the model server.

    Args:
        handler_service (str): Python path pointing to a module that defines
            a class with the following:

                - A ``handle`` method, which is invoked for all incoming inference
                  requests to the model server.
                - An ``initialize`` method, which is invoked at model server start
                  up for loading the model.

            Defaults to ``sagemaker_pytorch_serving_container.handler_service``.
    """
    if ENABLE_MULTI_MODEL:
        if not os.getenv("SAGEMAKER_HANDLER"):
            os.environ["SAGEMAKER_HANDLER"] = handler_service
        set_python_path()
    else:
        _adapt_to_ts_format(handler_service)

    _create_torchserve_config_file()

    if os.path.exists(REQUIREMENTS_PATH):
        install_requirements()

    ts_model_server_cmd = [
        "torchserve",
        "--start",
        "--model-store",
        MODEL_STORE,
        "--ts-config",
        TS_CONFIG_FILE,
        "--log-config",
        DEFAULT_TS_LOG_FILE,
        "--models",
        "model.mar",
    ]

    logger.info(ts_model_server_cmd)
    subprocess.Popen(ts_model_server_cmd)

    ts_process = retrieve_model_server_process(TS_NAMESPACE)

    add_sigterm_handler(ts_process)

    ts_process.wait()
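
For reference, a minimal sketch of how a serving container entrypoint might invoke this, assuming the module layout in this PR (the script name is illustrative):

    # serve.py - hypothetical container entrypoint
    from sagemaker_inference import torchserve

    if __name__ == "__main__":
        torchserve.start_model_server()  # blocks until the TorchServe process exits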


def _adapt_to_ts_format(handler_service):
    if not os.path.exists(DEFAULT_TS_MODEL_DIRECTORY):
        os.makedirs(DEFAULT_TS_MODEL_DIRECTORY)

    model_archiver_cmd = [
        "torch-model-archiver",
        "--model-name",
        DEFAULT_TS_MODEL_NAME,
        "--handler",
        handler_service,
        "--serialized-file",
        os.path.join(environment.model_dir, DEFAULT_TS_MODEL_SERIALIZED_FILE),
        "--export-path",
        DEFAULT_TS_MODEL_DIRECTORY,
        "--extra-files",
        os.path.join(environment.model_dir, environment.Environment().module_name + ".py"),
        "--version",
        "1",
    ]

    logger.info(model_archiver_cmd)
    subprocess.check_call(model_archiver_cmd)

    set_python_path()
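
The archiver call above packages the model as model.mar under DEFAULT_TS_MODEL_DIRECTORY, which is why start_model_server passes --models model.mar against that same --model-store. As a sketch, a quick sanity check after archiving:

    import os

    mar_path = os.path.join(DEFAULT_TS_MODEL_DIRECTORY, "model.mar")
    assert os.path.exists(mar_path), "torch-model-archiver did not produce model.mar"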


def _create_torchserve_config_file():
    configuration_properties = _generate_ts_config_properties()

    utils.write_file(TS_CONFIG_FILE, configuration_properties)


def _generate_ts_config_properties():
    env = environment.Environment()

    user_defined_configuration = {
        "default_response_timeout": env.model_server_timeout,
        "default_workers_per_model": env.model_server_workers,
        "inference_address": "http://0.0.0.0:{}".format(env.inference_http_port),
        "management_address": "http://0.0.0.0:{}".format(env.management_http_port),
    }

    custom_configuration = str()

    for key in user_defined_configuration:
        value = user_defined_configuration.get(key)
        if value:
            custom_configuration += "{}={}\n".format(key, value)

    if ENABLE_MULTI_MODEL:
        default_configuration = utils.read_file(MME_TS_CONFIG_FILE)
    else:
        default_configuration = utils.read_file(DEFAULT_TS_CONFIG_FILE)

    return default_configuration + custom_configuration
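
As an illustration, with SAGEMAKER_MODEL_SERVER_WORKERS=2 the generated /etc/sagemaker-ts.properties would be the default-ts.properties content above followed by lines along these lines (values are illustrative; the ports and timeout come from the SageMaker environment variables):

    default_response_timeout=60
    default_workers_per_model=2
    inference_address=http://0.0.0.0:8080
    management_address=http://0.0.0.0:8080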