|
| 1 | +# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). You |
| 4 | +# may not use this file except in compliance with the License. A copy of |
| 5 | +# the License is located at |
| 6 | +# |
| 7 | +# http://aws.amazon.com/apache2.0/ |
| 8 | +# |
| 9 | +# or in the "license" file accompanying this file. This file is |
| 10 | +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
| 11 | +# ANY KIND, either express or implied. See the License for the specific |
| 12 | +# language governing permissions and limitations under the License. |
| 13 | +"""This module contains functionality to configure and start Torchserve.""" |
| 14 | +from __future__ import absolute_import |
| 15 | + |
| 16 | +import os |
| 17 | +import signal |
| 18 | +import subprocess |
| 19 | +import sys |
| 20 | + |
| 21 | +import pkg_resources |
| 22 | +import psutil |
| 23 | +import logging |
| 24 | +from retrying import retry |
| 25 | + |
| 26 | +import sagemaker_pytorch_serving_container |
| 27 | +from sagemaker_inference import default_handler_service, environment, utils |
| 28 | +from sagemaker_inference.environment import code_dir |
| 29 | + |
| 30 | +logger = logging.getLogger() |
| 31 | + |
| 32 | +TS_CONFIG_FILE = os.path.join("/etc", "sagemaker-ts.properties") |
| 33 | +DEFAULT_HANDLER_SERVICE = default_handler_service.__name__ |
| 34 | +DEFAULT_TS_CONFIG_FILE = pkg_resources.resource_filename( |
| 35 | + sagemaker_pytorch_serving_container.__name__, "/etc/default-ts.properties" |
| 36 | +) |
| 37 | +MME_TS_CONFIG_FILE = pkg_resources.resource_filename( |
| 38 | + sagemaker_pytorch_serving_container.__name__, "/etc/mme-ts.properties" |
| 39 | +) |
| 40 | +DEFAULT_TS_LOG_FILE = pkg_resources.resource_filename( |
| 41 | + sagemaker_pytorch_serving_container.__name__, "/etc/log4j.properties" |
| 42 | +) |
| 43 | +DEFAULT_TS_MODEL_DIRECTORY = os.path.join(os.getcwd(), ".sagemaker", "ts", "models") |
| 44 | +DEFAULT_TS_MODEL_NAME = "model" |
| 45 | +DEFAULT_TS_MODEL_SERIALIZED_FILE = "model.pth" |
| 46 | +DEFAULT_HANDLER_SERVICE = "sagemaker_pytorch_serving_container.handler_service" |
| 47 | + |
| 48 | +ENABLE_MULTI_MODEL = os.getenv("SAGEMAKER_MULTI_MODEL", "false") == "true" |
| 49 | +MODEL_STORE = "/" if ENABLE_MULTI_MODEL else DEFAULT_TS_MODEL_DIRECTORY |
| 50 | + |
| 51 | +PYTHON_PATH_ENV = "PYTHONPATH" |
| 52 | +REQUIREMENTS_PATH = os.path.join(code_dir, "requirements.txt") |
| 53 | +TS_NAMESPACE = "org.pytorch.serve.ModelServer" |
| 54 | + |
| 55 | + |
| 56 | +def start_torchserve(handler_service=DEFAULT_HANDLER_SERVICE): |
| 57 | + """Configure and start the model server. |
| 58 | +
|
| 59 | + Args: |
| 60 | + handler_service (str): Python path pointing to a module that defines |
| 61 | + a class with the following: |
| 62 | +
|
| 63 | + - A ``handle`` method, which is invoked for all incoming inference |
| 64 | + requests to the model server. |
| 65 | + - A ``initialize`` method, which is invoked at model server start up |
| 66 | + for loading the model. |
| 67 | +
|
| 68 | + Defaults to ``sagemaker_pytorch_serving_container.default_handler_service``. |
| 69 | +
|
| 70 | + """ |
| 71 | + |
| 72 | + if ENABLE_MULTI_MODEL: |
| 73 | + if "SAGEMAKER_HANDLER" not in os.environ: |
| 74 | + os.environ["SAGEMAKER_HANDLER"] = handler_service |
| 75 | + _set_python_path() |
| 76 | + else: |
| 77 | + _adapt_to_ts_format(handler_service) |
| 78 | + |
| 79 | + _create_torchserve_config_file() |
| 80 | + |
| 81 | + if os.path.exists(REQUIREMENTS_PATH): |
| 82 | + _install_requirements() |
| 83 | + |
| 84 | + ts_torchserve_cmd = [ |
| 85 | + "torchserve", |
| 86 | + "--start", |
| 87 | + "--model-store", |
| 88 | + MODEL_STORE, |
| 89 | + "--ts-config", |
| 90 | + TS_CONFIG_FILE, |
| 91 | + "--log-config", |
| 92 | + DEFAULT_TS_LOG_FILE, |
| 93 | + "--models", |
| 94 | + "model.mar" |
| 95 | + ] |
| 96 | + |
| 97 | + print(ts_torchserve_cmd) |
| 98 | + |
| 99 | + logger.info(ts_torchserve_cmd) |
| 100 | + subprocess.Popen(ts_torchserve_cmd) |
| 101 | + |
| 102 | + ts_process = _retrieve_ts_server_process() |
| 103 | + |
| 104 | + _add_sigterm_handler(ts_process) |
| 105 | + |
| 106 | + ts_process.wait() |
| 107 | + |
| 108 | + |
| 109 | +def _adapt_to_ts_format(handler_service): |
| 110 | + if not os.path.exists(DEFAULT_TS_MODEL_DIRECTORY): |
| 111 | + os.makedirs(DEFAULT_TS_MODEL_DIRECTORY) |
| 112 | + |
| 113 | + model_archiver_cmd = [ |
| 114 | + "torch-model-archiver", |
| 115 | + "--model-name", |
| 116 | + DEFAULT_TS_MODEL_NAME, |
| 117 | + "--handler", |
| 118 | + handler_service, |
| 119 | + "--serialized-file", |
| 120 | + os.path.join(environment.model_dir, DEFAULT_TS_MODEL_SERIALIZED_FILE), |
| 121 | + "--export-path", |
| 122 | + DEFAULT_TS_MODEL_DIRECTORY, |
| 123 | + "--extra-files", |
| 124 | + os.path.join(environment.model_dir, environment.Environment().module_name + ".py"), |
| 125 | + "--version", |
| 126 | + "1", |
| 127 | + ] |
| 128 | + |
| 129 | + logger.info(model_archiver_cmd) |
| 130 | + subprocess.check_call(model_archiver_cmd) |
| 131 | + |
| 132 | + _set_python_path() |
| 133 | + |
| 134 | + |
| 135 | +def _set_python_path(): |
| 136 | + # Torchserve handles code execution by appending the export path, provided |
| 137 | + # to the model archiver, to the PYTHONPATH env var. |
| 138 | + # The code_dir has to be added to the PYTHONPATH otherwise the |
| 139 | + # user provided module can not be imported properly. |
| 140 | + if PYTHON_PATH_ENV in os.environ: |
| 141 | + os.environ[PYTHON_PATH_ENV] = "{}:{}".format(environment.code_dir, os.environ[PYTHON_PATH_ENV]) |
| 142 | + else: |
| 143 | + os.environ[PYTHON_PATH_ENV] = environment.code_dir |
| 144 | + |
| 145 | + |
| 146 | +def _create_torchserve_config_file(): |
| 147 | + configuration_properties = _generate_ts_config_properties() |
| 148 | + |
| 149 | + utils.write_file(TS_CONFIG_FILE, configuration_properties) |
| 150 | + |
| 151 | + |
| 152 | +def _generate_ts_config_properties(): |
| 153 | + env = environment.Environment() |
| 154 | + |
| 155 | + user_defined_configuration = { |
| 156 | + "default_response_timeout": env.model_server_timeout, |
| 157 | + "default_workers_per_model": env.model_server_workers, |
| 158 | + "inference_address": "http://0.0.0.0:{}".format(env.inference_http_port), |
| 159 | + "management_address": "http://0.0.0.0:{}".format(env.management_http_port), |
| 160 | + } |
| 161 | + |
| 162 | + custom_configuration = str() |
| 163 | + |
| 164 | + for key in user_defined_configuration: |
| 165 | + value = user_defined_configuration.get(key) |
| 166 | + if value: |
| 167 | + custom_configuration += "{}={}\n".format(key, value) |
| 168 | + |
| 169 | + if ENABLE_MULTI_MODEL: |
| 170 | + default_configuration = utils.read_file(MME_TS_CONFIG_FILE) |
| 171 | + else: |
| 172 | + default_configuration = utils.read_file(DEFAULT_TS_CONFIG_FILE) |
| 173 | + |
| 174 | + return default_configuration + custom_configuration |
| 175 | + |
| 176 | + |
| 177 | +def _add_sigterm_handler(ts_process): |
| 178 | + def _terminate(signo, frame): # pylint: disable=unused-argument |
| 179 | + try: |
| 180 | + os.kill(ts_process.pid, signal.SIGTERM) |
| 181 | + except OSError: |
| 182 | + pass |
| 183 | + |
| 184 | + signal.signal(signal.SIGTERM, _terminate) |
| 185 | + |
| 186 | + |
| 187 | +def _install_requirements(): |
| 188 | + logger.info("installing packages from requirements.txt...") |
| 189 | + pip_install_cmd = [sys.executable, "-m", "pip", "install", "-r", REQUIREMENTS_PATH] |
| 190 | + |
| 191 | + try: |
| 192 | + subprocess.check_call(pip_install_cmd) |
| 193 | + except subprocess.CalledProcessError: |
| 194 | + logger.exception("failed to install required packages, exiting") |
| 195 | + raise ValueError("failed to install required packages") |
| 196 | + |
| 197 | + |
| 198 | +# retry for 10 seconds |
| 199 | +@retry(stop_max_delay=10 * 1000) |
| 200 | +def _retrieve_ts_server_process(): |
| 201 | + ts_server_processes = list() |
| 202 | + |
| 203 | + for process in psutil.process_iter(): |
| 204 | + if TS_NAMESPACE in process.cmdline(): |
| 205 | + ts_server_processes.append(process) |
| 206 | + |
| 207 | + if not ts_server_processes: |
| 208 | + raise Exception("Torchserve model server was unsuccessfully started") |
| 209 | + |
| 210 | + if len(ts_server_processes) > 1: |
| 211 | + raise Exception("multiple ts model servers are not supported") |
| 212 | + |
| 213 | + return ts_server_processes[0] |
0 commit comments