Commit 2bb3871

Torchserve support for PyTorch Inference
1 parent deef6a8 commit 2bb3871

File tree: 8 files changed, +596 −157 lines changed

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+# Based on: https://github.com/awslabs/mxnet-model-server/blob/master/docs/configuration.md
+enable_envvars_config=true
+decode_input_request=false
+load_models=ALL
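The three settings above follow the mxnet-model-server/TorchServe configuration scheme: enable_envvars_config allows server properties to be overridden through environment variables at runtime, decode_input_request leaves request decoding to the handler service, and load_models=ALL loads every archive found in the model store at startup. As an illustrative sketch (not part of this commit; the exact variable names depend on the server version and are assumptions here), a container entrypoint could rely on the env-var override like this:

# Sketch only: with enable_envvars_config=true, prefixed environment
# variables can override properties without editing the file itself.
import os

os.environ["TS_DEFAULT_WORKERS_PER_MODEL"] = "2"   # hypothetical override
os.environ["TS_DEFAULT_RESPONSE_TIMEOUT"] = "120"  # hypothetical override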
Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@
+log4j.rootLogger = WARN, console
+
+log4j.appender.console = org.apache.log4j.ConsoleAppender
+log4j.appender.console.Target = System.out
+log4j.appender.console.layout = org.apache.log4j.PatternLayout
+log4j.appender.console.layout.ConversionPattern = %d{ISO8601} [%-5p] %t %c - %m%n
+
+log4j.appender.access_log = org.apache.log4j.RollingFileAppender
+log4j.appender.access_log.File = ${LOG_LOCATION}/access_log.log
+log4j.appender.access_log.MaxFileSize = 10MB
+log4j.appender.access_log.MaxBackupIndex = 5
+log4j.appender.access_log.layout = org.apache.log4j.PatternLayout
+log4j.appender.access_log.layout.ConversionPattern = %d{ISO8601} - %m%n
+
+log4j.appender.ts_log = org.apache.log4j.RollingFileAppender
+log4j.appender.ts_log.File = ${LOG_LOCATION}/ts_log.log
+log4j.appender.ts_log.MaxFileSize = 10MB
+log4j.appender.ts_log.MaxBackupIndex = 5
+log4j.appender.ts_log.layout = org.apache.log4j.PatternLayout
+log4j.appender.ts_log.layout.ConversionPattern = %d{ISO8601} [%-5p] %t %c - %m%n
+
+log4j.appender.ts_metrics = org.apache.log4j.RollingFileAppender
+log4j.appender.ts_metrics.File = ${METRICS_LOCATION}/ts_metrics.log
+log4j.appender.ts_metrics.MaxFileSize = 10MB
+log4j.appender.ts_metrics.MaxBackupIndex = 5
+log4j.appender.ts_metrics.layout = org.apache.log4j.PatternLayout
+log4j.appender.ts_metrics.layout.ConversionPattern = %d{ISO8601} - %m%n
+
+log4j.appender.model_log = org.apache.log4j.RollingFileAppender
+log4j.appender.model_log.File = ${LOG_LOCATION}/model_log.log
+log4j.appender.model_log.MaxFileSize = 10MB
+log4j.appender.model_log.MaxBackupIndex = 5
+log4j.appender.model_log.layout = org.apache.log4j.PatternLayout
+log4j.appender.model_log.layout.ConversionPattern = %d{ISO8601} [%-5p] %c - %m%n
+
+log4j.appender.model_metrics = org.apache.log4j.RollingFileAppender
+log4j.appender.model_metrics.File = ${METRICS_LOCATION}/model_metrics.log
+log4j.appender.model_metrics.MaxFileSize = 10MB
+log4j.appender.model_metrics.MaxBackupIndex = 5
+log4j.appender.model_metrics.layout = org.apache.log4j.PatternLayout
+log4j.appender.model_metrics.layout.ConversionPattern = %d{ISO8601} - %m%n
+
+log4j.logger.com.amazonaws.ml.ts = INFO, ts_log
+log4j.logger.ACCESS_LOG = INFO, access_log
+log4j.logger.TS_METRICS = WARN, ts_metrics
+log4j.logger.MODEL_METRICS = WARN, model_metrics
+log4j.logger.MODEL_LOG = WARN, model_log
+
+log4j.logger.org.apache = OFF
+log4j.logger.io.netty = ERROR
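This log4j configuration keeps console output at WARN and routes access, server, metrics, and model logs to size-bounded rolling files (10MB per file, 5 backups each). The ${LOG_LOCATION} and ${METRICS_LOCATION} placeholders are resolved by the model server at startup, typically from environment variables of the same name. A minimal sketch of preparing those locations before launch (the directory paths are assumptions, not part of this commit):

# Sketch only: make sure the advertised log/metrics directories exist
# before the model server process is spawned.
import os

os.environ.setdefault("LOG_LOCATION", "/opt/ml/model/logs")      # assumed path
os.environ.setdefault("METRICS_LOCATION", "/opt/ml/model/logs")  # assumed path
for path in {os.environ["LOG_LOCATION"], os.environ["METRICS_LOCATION"]}:
    os.makedirs(path, exist_ok=True)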

src/sagemaker_inference/model_server.py

Lines changed: 6 additions & 58 deletions
@@ -25,6 +25,7 @@

 import sagemaker_inference
 from sagemaker_inference import default_handler_service, environment, logging, utils
+from sagemaker_inference.model_server_utils import add_sigterm_handler, set_python_path, install_requirements, retrieve_model_server_process
 from sagemaker_inference.environment import code_dir

 logger = logging.get_logger()
@@ -70,14 +71,14 @@ def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):
     if ENABLE_MULTI_MODEL:
         if not os.getenv("SAGEMAKER_HANDLER"):
             os.environ["SAGEMAKER_HANDLER"] = handler_service
-        _set_python_path()
+        set_python_path()
     else:
         _adapt_to_mms_format(handler_service)

     _create_model_server_config_file()

     if os.path.exists(REQUIREMENTS_PATH):
-        _install_requirements()
+        install_requirements()

     mxnet_model_server_cmd = [
         "mxnet-model-server",
@@ -93,9 +94,9 @@ def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):
     logger.info(mxnet_model_server_cmd)
     subprocess.Popen(mxnet_model_server_cmd)

-    mms_process = _retrieve_mms_server_process()
+    mms_process = retrieve_model_server_process(MMS_NAMESPACE)

-    _add_sigterm_handler(mms_process)
+    add_sigterm_handler(mms_process)

     mms_process.wait()

@@ -121,21 +122,7 @@ def _adapt_to_mms_format(handler_service):
     logger.info(model_archiver_cmd)
     subprocess.check_call(model_archiver_cmd)

-    _set_python_path()
-
-
-def _set_python_path():
-    # MMS handles code execution by appending the export path, provided
-    # to the model archiver, to the PYTHONPATH env var.
-    # The code_dir has to be added to the PYTHONPATH otherwise the
-    # user provided module can not be imported properly.
-    code_dir_path = "{}:".format(environment.code_dir)
-
-    if PYTHON_PATH_ENV in os.environ:
-        os.environ[PYTHON_PATH_ENV] = code_dir_path + os.environ[PYTHON_PATH_ENV]
-    else:
-        os.environ[PYTHON_PATH_ENV] = code_dir_path
-
+    set_python_path()

 def _create_model_server_config_file():
     configuration_properties = _generate_mms_config_properties()
@@ -166,42 +153,3 @@ def _generate_mms_config_properties():
     default_configuration = utils.read_file(DEFAULT_MMS_CONFIG_FILE)

     return default_configuration + custom_configuration
-
-
-def _add_sigterm_handler(mms_process):
-    def _terminate(signo, frame):  # pylint: disable=unused-argument
-        try:
-            os.kill(mms_process.pid, signal.SIGTERM)
-        except OSError:
-            pass
-
-    signal.signal(signal.SIGTERM, _terminate)
-
-
-def _install_requirements():
-    logger.info("installing packages from requirements.txt...")
-    pip_install_cmd = [sys.executable, "-m", "pip", "install", "-r", REQUIREMENTS_PATH]
-
-    try:
-        subprocess.check_call(pip_install_cmd)
-    except subprocess.CalledProcessError:
-        logger.error("failed to install required packages, exiting")
-        raise ValueError("failed to install required packages")
-
-
-# retry for 10 seconds
-@retry(stop_max_delay=10 * 1000)
-def _retrieve_mms_server_process():
-    mms_server_processes = list()
-
-    for process in psutil.process_iter():
-        if MMS_NAMESPACE in process.cmdline():
-            mms_server_processes.append(process)
-
-    if not mms_server_processes:
-        raise Exception("mms model server was unsuccessfully started")
-
-    if len(mms_server_processes) > 1:
-        raise Exception("multiple mms model servers are not supported")
-
-    return mms_server_processes[0]
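Taken together, the changes to model_server.py are a pure refactor: _set_python_path, _add_sigterm_handler, _install_requirements, and _retrieve_mms_server_process move into the shared model_server_utils module so the new TorchServe code path can reuse them, while the MMS startup flow itself is unchanged. A minimal usage sketch (the handler module path is a placeholder, not taken from this diff) showing that callers start the server exactly as before:

# Sketch of a serving entrypoint; behavior is unchanged by the refactor.
from sagemaker_inference import model_server

if __name__ == "__main__":
    # "my_container.handler_service" is a hypothetical handler module path.
    model_server.start_model_server(handler_service="my_container.handler_service")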
src/sagemaker_inference/model_server_utils.py

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+import os
+import signal
+import subprocess
+import sys
+
+import pkg_resources
+import psutil
+from retrying import retry
+
+import sagemaker_inference
+from sagemaker_inference import environment, logging, utils
+from sagemaker_inference.environment import code_dir
+
+PYTHON_PATH_ENV = "PYTHONPATH"
+logger = logging.get_logger()
+REQUIREMENTS_PATH = os.path.join(code_dir, "requirements.txt")
+
+def add_sigterm_handler(mms_process):
+    def _terminate(signo, frame):  # pylint: disable=unused-argument
+        try:
+            os.kill(mms_process.pid, signal.SIGTERM)
+        except OSError:
+            pass
+
+    signal.signal(signal.SIGTERM, _terminate)
+
+
+def set_python_path():
+    # MMS handles code execution by appending the export path, provided
+    # to the model archiver, to the PYTHONPATH env var.
+    # The code_dir has to be added to the PYTHONPATH otherwise the
+    # user provided module can not be imported properly.
+    code_dir_path = "{}:".format(environment.code_dir)
+
+    if PYTHON_PATH_ENV in os.environ:
+        os.environ[PYTHON_PATH_ENV] = code_dir_path + os.environ[PYTHON_PATH_ENV]
+    else:
+        os.environ[PYTHON_PATH_ENV] = code_dir_path
+
+def install_requirements():
+    logger.info("installing packages from requirements.txt...")
+    pip_install_cmd = [sys.executable, "-m", "pip", "install", "-r", REQUIREMENTS_PATH]
+
+    try:
+        subprocess.check_call(pip_install_cmd)
+    except subprocess.CalledProcessError:
+        logger.error("failed to install required packages, exiting")
+        raise ValueError("failed to install required packages")
+
+
+# retry for 10 seconds
+@retry(stop_max_delay=10 * 1000)
+def retrieve_model_server_process(namespace):
+    model_server_processes = list()
+
+    for process in psutil.process_iter():
+        if namespace in process.cmdline():
+            model_server_processes.append(process)
+
+    if not model_server_processes:
+        raise Exception("model server was unsuccessfully started")
+
+    if len(model_server_processes) > 1:
+        raise Exception("multiple model servers are not supported")
+
+    return model_server_processes[0]
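These helpers are shared by the existing MMS-based model_server module and the new torchserve module; the only behavioral change from the private originals is that retrieve_model_server_process takes the server's process namespace as an argument instead of hard-coding the MMS one. An illustrative use (the namespace string is TS_NAMESPACE, defined in torchserve.py below; the returned psutil.Process supports pid and wait()):

# Sketch only: locate and supervise a running TorchServe process
# using the shared helpers.
from sagemaker_inference.model_server_utils import (
    add_sigterm_handler,
    retrieve_model_server_process,
)

ts_process = retrieve_model_server_process("org.pytorch.serve.ModelServer")
add_sigterm_handler(ts_process)  # forward SIGTERM from this process to the server
ts_process.wait()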

src/sagemaker_inference/torchserve.py

Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
+# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""This module contains functionality to configure and start the
+multi-model server."""
+from __future__ import absolute_import
+
+import os
+import signal
+import subprocess
+import sys
+import importlib
+
+import pkg_resources
+import psutil
+from retrying import retry
+
+import sagemaker_inference
+from sagemaker_inference import default_handler_service, environment, logging, utils
+from sagemaker_inference.model_server_utils import add_sigterm_handler, set_python_path, install_requirements, retrieve_model_server_process
+from sagemaker_inference.environment import code_dir
+
+logger = logging.get_logger()
+
+TS_CONFIG_FILE = os.path.join("/etc", "sagemaker-ts.properties")
+DEFAULT_TS_CONFIG_FILE = pkg_resources.resource_filename(
+    sagemaker_inference.__name__, "/etc/default-ts.properties"
+)
+MME_TS_CONFIG_FILE = pkg_resources.resource_filename(
+    sagemaker_inference.__name__, "/etc/mme-ts.properties"
+)
+DEFAULT_TS_LOG_FILE = pkg_resources.resource_filename(
+    sagemaker_inference.__name__, "/etc/ts.log4j.properties"
+)
+DEFAULT_TS_MODEL_DIRECTORY = os.path.join(os.getcwd(), ".sagemaker/ts/models")
+DEFAULT_TS_MODEL_NAME = "model"
+DEFAULT_TS_MODEL_SERIALIZED_FILE = "model.pth"
+DEFAULT_TS_HANDLER_SERVICE = "sagemaker_pytorch_serving_container.handler_service"
+
+ENABLE_MULTI_MODEL = os.getenv("SAGEMAKER_MULTI_MODEL", "false") == "true"
+MODEL_STORE = "/" if ENABLE_MULTI_MODEL else DEFAULT_TS_MODEL_DIRECTORY
+
+PYTHON_PATH_ENV = "PYTHONPATH"
+REQUIREMENTS_PATH = os.path.join(code_dir, "requirements.txt")
+TS_NAMESPACE = "org.pytorch.serve.ModelServer"
+
+
+def start_model_server(handler_service=DEFAULT_TS_HANDLER_SERVICE):
+    """Configure and start the model server.
+
+    Args:
+        handler_service (str): python path pointing to a module that defines
+            a class with the following:
+
+            - A ``handle`` method, which is invoked for all incoming inference
+              requests to the model server.
+            - A ``initialize`` method, which is invoked at model server start up
+              for loading the model.
+
+        Defaults to ``sagemaker_inference.default_handler_service``.
+
+    """
+
+    if ENABLE_MULTI_MODEL:
+        if not os.getenv("SAGEMAKER_HANDLER"):
+            os.environ["SAGEMAKER_HANDLER"] = handler_service
+        set_python_path()
+    else:
+        _adapt_to_ts_format(handler_service)
+
+    _create_torchserve_config_file()
+
+    if os.path.exists(REQUIREMENTS_PATH):
+        install_requirements()
+
+    ts_model_server_cmd = [
+        "torchserve",
+        "--start",
+        "--model-store",
+        MODEL_STORE,
+        "--ts-config",
+        TS_CONFIG_FILE,
+        "--log-config",
+        DEFAULT_TS_LOG_FILE,
+        "--models",
+        "model.mar"
+    ]
+
+    logger.info(ts_model_server_cmd)
+    subprocess.Popen(ts_model_server_cmd)
+
+    ts_process = retrieve_model_server_process(TS_NAMESPACE)
+
+    add_sigterm_handler(ts_process)
+
+    ts_process.wait()
+
+
+def _adapt_to_ts_format(handler_service):
+    if not os.path.exists(DEFAULT_TS_MODEL_DIRECTORY):
+        os.makedirs(DEFAULT_TS_MODEL_DIRECTORY)
+
+
+    model_archiver_cmd = [
+        "torch-model-archiver",
+        "--model-name",
+        DEFAULT_TS_MODEL_NAME,
+        "--handler",
+        handler_service,
+        "--serialized-file",
+        os.path.join(environment.model_dir, DEFAULT_TS_MODEL_SERIALIZED_FILE),
+        "--export-path",
+        DEFAULT_TS_MODEL_DIRECTORY,
+        "--extra-files",
+        os.path.join(environment.model_dir, environment.Environment().module_name + ".py"),
+        "--version",
+        "1",
+    ]
+
+    logger.info(model_archiver_cmd)
+    subprocess.check_call(model_archiver_cmd)
+
+    set_python_path()
+
+
+def _create_torchserve_config_file():
+    configuration_properties = _generate_ts_config_properties()
+
+    utils.write_file(TS_CONFIG_FILE, configuration_properties)
+
+
+def _generate_ts_config_properties():
+    env = environment.Environment()
+
+    user_defined_configuration = {
+        "default_response_timeout": env.model_server_timeout,
+        "default_workers_per_model": env.model_server_workers,
+        "inference_address": "http://0.0.0.0:{}".format(env.inference_http_port),
+        "management_address": "http://0.0.0.0:{}".format(env.management_http_port),
+    }
+
+    custom_configuration = str()
+
+    for key in user_defined_configuration:
+        value = user_defined_configuration.get(key)
+        if value:
+            custom_configuration += "{}={}\n".format(key, value)
+
+    if ENABLE_MULTI_MODEL:
+        default_configuration = utils.read_file(MME_TS_CONFIG_FILE)
+    else:
+        default_configuration = utils.read_file(DEFAULT_TS_CONFIG_FILE)
+
+    return default_configuration + custom_configuration
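Worked example (the values shown are hypothetical, not defaults guaranteed by this diff): if the SageMaker environment reports a 60-second timeout, 4 workers per model, and ports 8080/8081, _generate_ts_config_properties appends the following key=value overrides to the packaged default-ts.properties (or mme-ts.properties when SAGEMAKER_MULTI_MODEL is "true"), and _create_torchserve_config_file writes the combined result to /etc/sagemaker-ts.properties:

# Sketch of the custom configuration string produced for the assumed
# environment values above; keys with falsy values are skipped entirely.
expected_custom_configuration = (
    "default_response_timeout=60\n"
    "default_workers_per_model=4\n"
    "inference_address=http://0.0.0.0:8080\n"
    "management_address=http://0.0.0.0:8081\n"
)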
