Commit a3a08d0

breaking: Change Model server to Torchserve for PyTorch Inference (#79)

1 parent: c4e7abc
20 files changed: +761 -28 lines
src/sagemaker_pytorch_serving_container/default_inference_handler.py renamed to src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py (file renamed without changes)

setup.py (+2 -1)

@@ -31,6 +31,7 @@ def read(fname):
 
     packages=find_packages(where='src', exclude=('test',)),
     package_dir={'': 'src'},
+    package_data={'': ["etc/*"]},
    py_modules=[splitext(basename(path))[0] for path in glob('src/*.py')],
 
     long_description=read('README.rst'),
@@ -56,7 +57,7 @@ def read(fname):
        'test': ['boto3==1.10.32', 'coverage==4.5.3', 'docker-compose==1.23.2', 'flake8==3.7.7', 'Flask==1.1.1',
                 'mock==2.0.0', 'pytest==4.4.0', 'pytest-cov==2.7.1', 'pytest-xdist==1.28.0', 'PyYAML==3.10',
                 'sagemaker==1.56.3', 'sagemaker-containers>=2.5.4', 'six==1.12.0', 'requests==2.20.0',
-                'requests_mock==1.6.0', 'torch==1.5.0', 'torchvision==0.6.0', 'tox==3.7.0']
+                'requests_mock==1.6.0', 'torch==1.6.0', 'torchvision==0.7.0', 'tox==3.7.0']
     },
 
     entry_points={
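The new package_data entry ships the etc/* files (the TorchServe properties and log4j configs added below) inside the installed package, so the serving code can resolve them at runtime. A minimal sketch of that lookup, mirroring what the new torchserve.py module does:

# Sketch: locating a config file installed via package_data.
# Package and resource names are taken from this commit; this only
# resolves where the package is actually installed.
import pkg_resources

config_path = pkg_resources.resource_filename(
    "sagemaker_pytorch_serving_container", "/etc/default-ts.properties"
)
print(config_path)  # absolute path to the packaged properties file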
src/sagemaker_pytorch_serving_container/etc/default-ts.properties (new file, +4)

# Based on https://github.com/pytorch/serve/blob/master/docs/configuration.md
enable_envvars_config=true
decode_input_request=false
load_models=ALL
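Per the TorchServe configuration docs referenced in the file header: enable_envvars_config=true allows server properties to be overridden through TS_-prefixed environment variables, decode_input_request=false hands request bodies to the handler as raw bytes instead of pre-decoding them, and load_models=ALL registers every archive found in the model store at startup. A small sketch of the env-var override, with the variable naming assumed from those docs:

# Sketch: with enable_envvars_config=true, TorchServe honors
# TS_<PROPERTY_NAME> environment variables as config overrides.
import os

os.environ["TS_DEFAULT_WORKERS_PER_MODEL"] = "2"  # overrides default_workers_per_model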
src/sagemaker_pytorch_serving_container/etc/log4j.properties (new file, +50)

log4j.rootLogger = INFO, console

log4j.appender.console = org.apache.log4j.ConsoleAppender
log4j.appender.console.Target = System.out
log4j.appender.console.layout = org.apache.log4j.PatternLayout
log4j.appender.console.layout.ConversionPattern = %d{ISO8601} [%-5p] %t %c - %m%n

log4j.appender.access_log = org.apache.log4j.RollingFileAppender
log4j.appender.access_log.File = ${LOG_LOCATION}/access_log.log
log4j.appender.access_log.MaxFileSize = 10MB
log4j.appender.access_log.MaxBackupIndex = 5
log4j.appender.access_log.layout = org.apache.log4j.PatternLayout
log4j.appender.access_log.layout.ConversionPattern = %d{ISO8601} - %m%n

log4j.appender.ts_log = org.apache.log4j.RollingFileAppender
log4j.appender.ts_log.File = ${LOG_LOCATION}/ts_log.log
log4j.appender.ts_log.MaxFileSize = 10MB
log4j.appender.ts_log.MaxBackupIndex = 5
log4j.appender.ts_log.layout = org.apache.log4j.PatternLayout
log4j.appender.ts_log.layout.ConversionPattern = %d{ISO8601} [%-5p] %t %c - %m%n

log4j.appender.ts_metrics = org.apache.log4j.RollingFileAppender
log4j.appender.ts_metrics.File = ${METRICS_LOCATION}/ts_metrics.log
log4j.appender.ts_metrics.MaxFileSize = 10MB
log4j.appender.ts_metrics.MaxBackupIndex = 5
log4j.appender.ts_metrics.layout = org.apache.log4j.PatternLayout
log4j.appender.ts_metrics.layout.ConversionPattern = %d{ISO8601} - %m%n

log4j.appender.model_log = org.apache.log4j.RollingFileAppender
log4j.appender.model_log.File = ${LOG_LOCATION}/model_log.log
log4j.appender.model_log.MaxFileSize = 10MB
log4j.appender.model_log.MaxBackupIndex = 5
log4j.appender.model_log.layout = org.apache.log4j.PatternLayout
log4j.appender.model_log.layout.ConversionPattern = %d{ISO8601} [%-5p] %c - %m%n

log4j.appender.model_metrics = org.apache.log4j.RollingFileAppender
log4j.appender.model_metrics.File = ${METRICS_LOCATION}/model_metrics.log
log4j.appender.model_metrics.MaxFileSize = 10MB
log4j.appender.model_metrics.MaxBackupIndex = 5
log4j.appender.model_metrics.layout = org.apache.log4j.PatternLayout
log4j.appender.model_metrics.layout.ConversionPattern = %d{ISO8601} - %m%n

log4j.logger.com.amazonaws.ml.ts = INFO, ts_log
log4j.logger.ACCESS_LOG = INFO, access_log
log4j.logger.TS_METRICS = INFO, ts_metrics
log4j.logger.MODEL_METRICS = INFO, model_metrics
log4j.logger.MODEL_LOG = INFO, model_log

log4j.logger.org.apache = OFF
log4j.logger.io.netty = ERROR
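${LOG_LOCATION} and ${METRICS_LOCATION} are log4j substitution variables; TorchServe appears to resolve them from LOG_LOCATION and METRICS_LOCATION environment variables (an assumption based on the TorchServe logging docs), so the container can redirect log and metric files without editing this config:

# Sketch: pointing the file appenders above at container paths before
# launching torchserve (env var names assumed from TorchServe's docs).
import os

os.environ["LOG_LOCATION"] = "/opt/ml/logs"
os.environ["METRICS_LOCATION"] = "/opt/ml/metrics"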

src/sagemaker_pytorch_serving_container/handler_service.py (+1 -2)

@@ -14,8 +14,7 @@
 
 from sagemaker_inference.default_handler_service import DefaultHandlerService
 from sagemaker_inference.transformer import Transformer
-from sagemaker_pytorch_serving_container.default_inference_handler import \
-    DefaultPytorchInferenceHandler
+from sagemaker_pytorch_serving_container.default_pytorch_inference_handler import DefaultPytorchInferenceHandler
 
 import os
 import sys

src/sagemaker_pytorch_serving_container/serving.py (+5 -6)

@@ -15,11 +15,10 @@
 from subprocess import CalledProcessError
 
 from retrying import retry
-from sagemaker_inference import model_server
-
+from sagemaker_pytorch_serving_container import torchserve
 from sagemaker_pytorch_serving_container import handler_service
 
-HANDLER_SERVICE = handler_service.__name__
+HANDLER_SERVICE = handler_service.__file__
 
 
 def _retry_if_error(exception):
@@ -28,12 +27,12 @@ def _retry_if_error(exception):
 
 @retry(stop_max_delay=1000 * 30,
        retry_on_exception=_retry_if_error)
-def _start_model_server():
+def _start_torchserve():
     # there's a race condition that causes the model server command to
     # sometimes fail with 'bad address'. more investigation needed
     # retry starting mms until it's ready
-    model_server.start_model_server(handler_service=HANDLER_SERVICE)
+    torchserve.start_torchserve(handler_service=HANDLER_SERVICE)
 
 
 def main():
-    _start_model_server()
+    _start_torchserve()
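The change from handler_service.__name__ to handler_service.__file__ matches the new server: MMS addressed the handler by dotted module name, while torch-model-archiver's --handler flag (used in torchserve.py below) takes the path to the handler's Python file. Illustration, with values shown as examples:

# The two handler references; the path on the second line is illustrative.
from sagemaker_pytorch_serving_container import handler_service

handler_service.__name__  # "sagemaker_pytorch_serving_container.handler_service" (old, for MMS)
handler_service.__file__  # ".../sagemaker_pytorch_serving_container/handler_service.py" (new, for torch-model-archiver)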
src/sagemaker_pytorch_serving_container/torchserve.py (new file, +213)

# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains functionality to configure and start Torchserve."""
from __future__ import absolute_import

import os
import signal
import subprocess
import sys

import pkg_resources
import psutil
import logging
from retrying import retry

import sagemaker_pytorch_serving_container
from sagemaker_inference import default_handler_service, environment, utils
from sagemaker_inference.environment import code_dir

logger = logging.getLogger()

TS_CONFIG_FILE = os.path.join("/etc", "sagemaker-ts.properties")
DEFAULT_HANDLER_SERVICE = default_handler_service.__name__
DEFAULT_TS_CONFIG_FILE = pkg_resources.resource_filename(
    sagemaker_pytorch_serving_container.__name__, "/etc/default-ts.properties"
)
MME_TS_CONFIG_FILE = pkg_resources.resource_filename(
    sagemaker_pytorch_serving_container.__name__, "/etc/mme-ts.properties"
)
DEFAULT_TS_LOG_FILE = pkg_resources.resource_filename(
    sagemaker_pytorch_serving_container.__name__, "/etc/log4j.properties"
)
DEFAULT_TS_MODEL_DIRECTORY = os.path.join(os.getcwd(), ".sagemaker", "ts", "models")
DEFAULT_TS_MODEL_NAME = "model"
DEFAULT_TS_MODEL_SERIALIZED_FILE = "model.pth"
DEFAULT_HANDLER_SERVICE = "sagemaker_pytorch_serving_container.handler_service"

ENABLE_MULTI_MODEL = os.getenv("SAGEMAKER_MULTI_MODEL", "false") == "true"
MODEL_STORE = "/" if ENABLE_MULTI_MODEL else DEFAULT_TS_MODEL_DIRECTORY

PYTHON_PATH_ENV = "PYTHONPATH"
REQUIREMENTS_PATH = os.path.join(code_dir, "requirements.txt")
TS_NAMESPACE = "org.pytorch.serve.ModelServer"


def start_torchserve(handler_service=DEFAULT_HANDLER_SERVICE):
    """Configure and start the model server.

    Args:
        handler_service (str): Python path pointing to a module that defines
            a class with the following:

            - A ``handle`` method, which is invoked for all incoming inference
              requests to the model server.
            - An ``initialize`` method, which is invoked at model server start up
              for loading the model.

            Defaults to ``sagemaker_pytorch_serving_container.default_handler_service``.

    """

    if ENABLE_MULTI_MODEL:
        if "SAGEMAKER_HANDLER" not in os.environ:
            os.environ["SAGEMAKER_HANDLER"] = handler_service
        _set_python_path()
    else:
        _adapt_to_ts_format(handler_service)

    _create_torchserve_config_file()

    if os.path.exists(REQUIREMENTS_PATH):
        _install_requirements()

    ts_torchserve_cmd = [
        "torchserve",
        "--start",
        "--model-store",
        MODEL_STORE,
        "--ts-config",
        TS_CONFIG_FILE,
        "--log-config",
        DEFAULT_TS_LOG_FILE,
        "--models",
        "model.mar"
    ]

    print(ts_torchserve_cmd)

    logger.info(ts_torchserve_cmd)
    subprocess.Popen(ts_torchserve_cmd)

    ts_process = _retrieve_ts_server_process()

    _add_sigterm_handler(ts_process)

    ts_process.wait()


def _adapt_to_ts_format(handler_service):
    if not os.path.exists(DEFAULT_TS_MODEL_DIRECTORY):
        os.makedirs(DEFAULT_TS_MODEL_DIRECTORY)

    model_archiver_cmd = [
        "torch-model-archiver",
        "--model-name",
        DEFAULT_TS_MODEL_NAME,
        "--handler",
        handler_service,
        "--serialized-file",
        os.path.join(environment.model_dir, DEFAULT_TS_MODEL_SERIALIZED_FILE),
        "--export-path",
        DEFAULT_TS_MODEL_DIRECTORY,
        "--extra-files",
        os.path.join(environment.model_dir, environment.Environment().module_name + ".py"),
        "--version",
        "1",
    ]

    logger.info(model_archiver_cmd)
    subprocess.check_call(model_archiver_cmd)

    _set_python_path()


def _set_python_path():
    # Torchserve handles code execution by appending the export path, provided
    # to the model archiver, to the PYTHONPATH env var.
    # The code_dir has to be added to the PYTHONPATH otherwise the
    # user provided module can not be imported properly.
    if PYTHON_PATH_ENV in os.environ:
        os.environ[PYTHON_PATH_ENV] = "{}:{}".format(environment.code_dir, os.environ[PYTHON_PATH_ENV])
    else:
        os.environ[PYTHON_PATH_ENV] = environment.code_dir


def _create_torchserve_config_file():
    configuration_properties = _generate_ts_config_properties()

    utils.write_file(TS_CONFIG_FILE, configuration_properties)


def _generate_ts_config_properties():
    env = environment.Environment()

    user_defined_configuration = {
        "default_response_timeout": env.model_server_timeout,
        "default_workers_per_model": env.model_server_workers,
        "inference_address": "http://0.0.0.0:{}".format(env.inference_http_port),
        "management_address": "http://0.0.0.0:{}".format(env.management_http_port),
    }

    custom_configuration = str()

    for key in user_defined_configuration:
        value = user_defined_configuration.get(key)
        if value:
            custom_configuration += "{}={}\n".format(key, value)

    if ENABLE_MULTI_MODEL:
        default_configuration = utils.read_file(MME_TS_CONFIG_FILE)
    else:
        default_configuration = utils.read_file(DEFAULT_TS_CONFIG_FILE)

    return default_configuration + custom_configuration


def _add_sigterm_handler(ts_process):
    def _terminate(signo, frame):  # pylint: disable=unused-argument
        try:
            os.kill(ts_process.pid, signal.SIGTERM)
        except OSError:
            pass

    signal.signal(signal.SIGTERM, _terminate)


def _install_requirements():
    logger.info("installing packages from requirements.txt...")
    pip_install_cmd = [sys.executable, "-m", "pip", "install", "-r", REQUIREMENTS_PATH]

    try:
        subprocess.check_call(pip_install_cmd)
    except subprocess.CalledProcessError:
        logger.exception("failed to install required packages, exiting")
        raise ValueError("failed to install required packages")


# retry for 10 seconds
@retry(stop_max_delay=10 * 1000)
def _retrieve_ts_server_process():
    ts_server_processes = list()

    for process in psutil.process_iter():
        if TS_NAMESPACE in process.cmdline():
            ts_server_processes.append(process)

    if not ts_server_processes:
        raise Exception("Torchserve model server was unsuccessfully started")

    if len(ts_server_processes) > 1:
        raise Exception("multiple ts model servers are not supported")

    return ts_server_processes[0]
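For a single-model endpoint (SAGEMAKER_MULTI_MODEL unset), start_torchserve() boils down to three steps: archive the model with torch-model-archiver, write /etc/sagemaker-ts.properties, then launch and supervise the torchserve process. A condensed sketch of the two commands it issues; the concrete paths, the module name "inference", and the install location are assumed example values:

# Condensed sketch of the single-model flow above; all concrete
# paths below are assumed examples, not values fixed by this commit.
archiver_cmd = [
    "torch-model-archiver",
    "--model-name", "model",
    "--handler", "/usr/local/lib/python3.6/site-packages/sagemaker_pytorch_serving_container/handler_service.py",
    "--serialized-file", "/opt/ml/model/model.pth",
    "--export-path", ".sagemaker/ts/models",
    "--extra-files", "/opt/ml/model/inference.py",
    "--version", "1",
]

serve_cmd = [
    "torchserve", "--start",
    "--model-store", ".sagemaker/ts/models",
    "--ts-config", "/etc/sagemaker-ts.properties",  # packaged defaults + env-derived overrides
    "--log-config", "/usr/local/lib/python3.6/site-packages/sagemaker_pytorch_serving_container/etc/log4j.properties",
    "--models", "model.mar",
]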

test/conftest.py (+1 -1)

@@ -53,7 +53,7 @@ def pytest_addoption(parser):
     parser.addoption('--accelerator-type')
     parser.addoption('--docker-base-name', default='sagemaker-pytorch-inference')
     parser.addoption('--region', default='us-west-2')
-    parser.addoption('--framework-version', default="1.5.0")
+    parser.addoption('--framework-version', default="1.6.0")
     parser.addoption('--py-version', choices=['2', '3'], default='3')
     # Processor is still "cpu" for EIA tests
     parser.addoption('--processor', choices=['gpu', 'cpu'], default='cpu')
Dockerfile (CPU test image) (+12)

@@ -1,6 +1,18 @@
 ARG region
 FROM 763104351884.dkr.ecr.$region.amazonaws.com/pytorch-inference:1.5.0-cpu-py3
 
+ARG TS_VERSION=0.1.1
+RUN apt-get update \
+    && apt-get install -y --no-install-recommends software-properties-common \
+    && add-apt-repository ppa:openjdk-r/ppa \
+    && apt-get update \
+    && apt-get install -y --no-install-recommends openjdk-11-jdk
+
+RUN pip install torchserve==$TS_VERSION \
+    && pip install torch-model-archiver==$TS_VERSION
+
 COPY dist/sagemaker_pytorch_inference-*.tar.gz /sagemaker_pytorch_inference.tar.gz
 RUN pip install --upgrade --no-cache-dir /sagemaker_pytorch_inference.tar.gz && \
     rm /sagemaker_pytorch_inference.tar.gz
+
+CMD ["torchserve", "--start", "--ts-config", "/home/model-server/config.properties", "--model-store", "/home/model-server/"]
Dockerfile (GPU test image) (+13 -1)

@@ -1,6 +1,18 @@
 ARG region
-FROM 763104351884.dkr.ecr.$region.amazonaws.com/pytorch-inference:1.5.0-gpu-py3
+FROM 763104351884.dkr.ecr.$region.amazonaws.com/pytorch-inference:1.5.0-cpu-py3
+
+ARG TS_VERSION=0.1.1
+RUN apt-get update \
+    && apt-get install -y --no-install-recommends software-properties-common \
+    && add-apt-repository ppa:openjdk-r/ppa \
+    && apt-get update \
+    && apt-get install -y --no-install-recommends openjdk-11-jdk
+
+RUN pip install torchserve==$TS_VERSION \
+    && pip install torch-model-archiver==$TS_VERSION
 
 COPY dist/sagemaker_pytorch_inference-*.tar.gz /sagemaker_pytorch_inference.tar.gz
 RUN pip install --upgrade --no-cache-dir /sagemaker_pytorch_inference.tar.gz && \
     rm /sagemaker_pytorch_inference.tar.gz
+
+CMD ["torchserve", "--start", "--ts-config", "/home/model-server/config.properties", "--model-store", "/home/model-server/"]
