diff --git a/src/sagemaker_inference/etc/default-ts.properties b/src/sagemaker_inference/etc/default-ts.properties
new file mode 100644
index 0000000..c95ac32
--- /dev/null
+++ b/src/sagemaker_inference/etc/default-ts.properties
@@ -0,0 +1,4 @@
+# Based on: https://github.com/awslabs/mxnet-model-server/blob/master/docs/configuration.md
+enable_envvars_config=true
+decode_input_request=false
+load_models=ALL
diff --git a/src/sagemaker_inference/etc/ts.log4j.properties b/src/sagemaker_inference/etc/ts.log4j.properties
new file mode 100644
index 0000000..fe94e50
--- /dev/null
+++ b/src/sagemaker_inference/etc/ts.log4j.properties
@@ -0,0 +1,50 @@
+log4j.rootLogger = WARN, console
+
+log4j.appender.console = org.apache.log4j.ConsoleAppender
+log4j.appender.console.Target = System.out
+log4j.appender.console.layout = org.apache.log4j.PatternLayout
+log4j.appender.console.layout.ConversionPattern = %d{ISO8601} [%-5p] %t %c - %m%n
+
+log4j.appender.access_log = org.apache.log4j.RollingFileAppender
+log4j.appender.access_log.File = ${LOG_LOCATION}/access_log.log
+log4j.appender.access_log.MaxFileSize = 10MB
+log4j.appender.access_log.MaxBackupIndex = 5
+log4j.appender.access_log.layout = org.apache.log4j.PatternLayout
+log4j.appender.access_log.layout.ConversionPattern = %d{ISO8601} - %m%n
+
+log4j.appender.ts_log = org.apache.log4j.RollingFileAppender
+log4j.appender.ts_log.File = ${LOG_LOCATION}/ts_log.log
+log4j.appender.ts_log.MaxFileSize = 10MB
+log4j.appender.ts_log.MaxBackupIndex = 5
+log4j.appender.ts_log.layout = org.apache.log4j.PatternLayout
+log4j.appender.ts_log.layout.ConversionPattern = %d{ISO8601} [%-5p] %t %c - %m%n
+
+log4j.appender.ts_metrics = org.apache.log4j.RollingFileAppender
+log4j.appender.ts_metrics.File = ${METRICS_LOCATION}/ts_metrics.log
+log4j.appender.ts_metrics.MaxFileSize = 10MB
+log4j.appender.ts_metrics.MaxBackupIndex = 5
+log4j.appender.ts_metrics.layout = org.apache.log4j.PatternLayout
+log4j.appender.ts_metrics.layout.ConversionPattern = %d{ISO8601} - %m%n
+
+log4j.appender.model_log = org.apache.log4j.RollingFileAppender
+log4j.appender.model_log.File = ${LOG_LOCATION}/model_log.log
+log4j.appender.model_log.MaxFileSize = 10MB
+log4j.appender.model_log.MaxBackupIndex = 5
+log4j.appender.model_log.layout = org.apache.log4j.PatternLayout
+log4j.appender.model_log.layout.ConversionPattern = %d{ISO8601} [%-5p] %c - %m%n
+
+log4j.appender.model_metrics = org.apache.log4j.RollingFileAppender
+log4j.appender.model_metrics.File = ${METRICS_LOCATION}/model_metrics.log
+log4j.appender.model_metrics.MaxFileSize = 10MB
+log4j.appender.model_metrics.MaxBackupIndex = 5
+log4j.appender.model_metrics.layout = org.apache.log4j.PatternLayout
+log4j.appender.model_metrics.layout.ConversionPattern = %d{ISO8601} - %m%n
+
+log4j.logger.com.amazonaws.ml.ts = INFO, ts_log
+log4j.logger.ACCESS_LOG = INFO, access_log
+log4j.logger.TS_METRICS = WARN, ts_metrics
+log4j.logger.MODEL_METRICS = WARN, model_metrics
+log4j.logger.MODEL_LOG = WARN, model_log
+
+log4j.logger.org.apache = OFF
+log4j.logger.io.netty = ERROR
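The `${LOG_LOCATION}` and `${METRICS_LOCATION}` placeholders above are substituted by the server's Java frontend at startup; exactly how they are supplied (environment variables versus JVM system properties) is server-internal, so the manual launch sketch below, including its paths, is an assumption for illustration rather than part of this change:

```python
# Hypothetical manual launch: point the log4j substitutions at a writable
# directory before starting TorchServe. Paths here are illustrative.
import os
import subprocess

os.environ["LOG_LOCATION"] = "/opt/ml/model/logs"         # assumed mount point
os.environ["METRICS_LOCATION"] = "/opt/ml/model/metrics"  # assumed mount point

subprocess.check_call(
    ["torchserve", "--start", "--log-config", "ts.log4j.properties"]
)
```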
diff --git a/src/sagemaker_inference/model_server.py b/src/sagemaker_inference/model_server.py
index 754c63b..b017115 100644
--- a/src/sagemaker_inference/model_server.py
+++ b/src/sagemaker_inference/model_server.py
@@ -25,6 +25,12 @@
 import sagemaker_inference
 from sagemaker_inference import default_handler_service, environment, logging, utils
+from sagemaker_inference.model_server_utils import (
+    add_sigterm_handler,
+    install_requirements,
+    retrieve_model_server_process,
+    set_python_path,
+)
 from sagemaker_inference.environment import code_dir
 
 logger = logging.get_logger()
@@ -70,14 +71,14 @@ def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):
     if ENABLE_MULTI_MODEL:
         if not os.getenv("SAGEMAKER_HANDLER"):
             os.environ["SAGEMAKER_HANDLER"] = handler_service
-        _set_python_path()
+        set_python_path()
     else:
         _adapt_to_mms_format(handler_service)
 
     _create_model_server_config_file()
 
     if os.path.exists(REQUIREMENTS_PATH):
-        _install_requirements()
+        install_requirements()
 
     mxnet_model_server_cmd = [
         "mxnet-model-server",
@@ -93,9 +94,9 @@ def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):
     logger.info(mxnet_model_server_cmd)
     subprocess.Popen(mxnet_model_server_cmd)
 
-    mms_process = _retrieve_mms_server_process()
+    mms_process = retrieve_model_server_process(MMS_NAMESPACE)
 
-    _add_sigterm_handler(mms_process)
+    add_sigterm_handler(mms_process)
 
     mms_process.wait()
@@ -121,21 +122,8 @@ def _adapt_to_mms_format(handler_service):
     logger.info(model_archiver_cmd)
     subprocess.check_call(model_archiver_cmd)
 
-    _set_python_path()
-
-
-def _set_python_path():
-    # MMS handles code execution by appending the export path, provided
-    # to the model archiver, to the PYTHONPATH env var.
-    # The code_dir has to be added to the PYTHONPATH otherwise the
-    # user provided module can not be imported properly.
-    code_dir_path = "{}:".format(environment.code_dir)
-
-    if PYTHON_PATH_ENV in os.environ:
-        os.environ[PYTHON_PATH_ENV] = code_dir_path + os.environ[PYTHON_PATH_ENV]
-    else:
-        os.environ[PYTHON_PATH_ENV] = code_dir_path
-
+    set_python_path()
+
 
 def _create_model_server_config_file():
     configuration_properties = _generate_mms_config_properties()
@@ -166,42 +153,3 @@ def _generate_mms_config_properties():
         default_configuration = utils.read_file(DEFAULT_MMS_CONFIG_FILE)
 
     return default_configuration + custom_configuration
-
-
-def _add_sigterm_handler(mms_process):
-    def _terminate(signo, frame):  # pylint: disable=unused-argument
-        try:
-            os.kill(mms_process.pid, signal.SIGTERM)
-        except OSError:
-            pass
-
-    signal.signal(signal.SIGTERM, _terminate)
-
-
-def _install_requirements():
-    logger.info("installing packages from requirements.txt...")
-    pip_install_cmd = [sys.executable, "-m", "pip", "install", "-r", REQUIREMENTS_PATH]
-
-    try:
-        subprocess.check_call(pip_install_cmd)
-    except subprocess.CalledProcessError:
-        logger.error("failed to install required packages, exiting")
-        raise ValueError("failed to install required packages")
-
-
-# retry for 10 seconds
-@retry(stop_max_delay=10 * 1000)
-def _retrieve_mms_server_process():
-    mms_server_processes = list()
-
-    for process in psutil.process_iter():
-        if MMS_NAMESPACE in process.cmdline():
-            mms_server_processes.append(process)
-
-    if not mms_server_processes:
-        raise Exception("mms model server was unsuccessfully started")
-
-    if len(mms_server_processes) > 1:
-        raise Exception("multiple mms model servers are not supported")
-
-    return mms_server_processes[0]
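The net effect of this refactor: `start_model_server` keeps its flow but delegates the generic steps to `model_server_utils`, with the process namespace passed in instead of hard-coded. A condensed sketch of that flow (server flags trimmed for brevity; only the call sequence matters here):

```python
# Condensed view of the refactored start_model_server flow.
import subprocess

from sagemaker_inference.model_server import MMS_NAMESPACE
from sagemaker_inference.model_server_utils import (
    add_sigterm_handler,
    retrieve_model_server_process,
    set_python_path,
)

set_python_path()                                    # prepend code_dir to PYTHONPATH
subprocess.Popen(["mxnet-model-server", "--start"])  # real command carries more flags
process = retrieve_model_server_process(MMS_NAMESPACE)
add_sigterm_handler(process)                         # forward SIGTERM to the server
process.wait()
```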
diff --git a/src/sagemaker_inference/model_server_utils.py b/src/sagemaker_inference/model_server_utils.py
new file mode 100644
index 0000000..870d244
--- /dev/null
+++ b/src/sagemaker_inference/model_server_utils.py
@@ -0,0 +1,70 @@
+"""Utility functions shared by the MMS and TorchServe launchers."""
+import os
+import signal
+import subprocess
+import sys
+
+import psutil
+from retrying import retry
+
+from sagemaker_inference import environment, logging
+from sagemaker_inference.environment import code_dir
+
+PYTHON_PATH_ENV = "PYTHONPATH"
+REQUIREMENTS_PATH = os.path.join(code_dir, "requirements.txt")
+
+logger = logging.get_logger()
+
+
+def add_sigterm_handler(server_process):
+    """Forward SIGTERM to the model server process so it shuts down cleanly."""
+
+    def _terminate(signo, frame):  # pylint: disable=unused-argument
+        try:
+            os.kill(server_process.pid, signal.SIGTERM)
+        except OSError:
+            pass
+
+    signal.signal(signal.SIGTERM, _terminate)
+
+
+def set_python_path():
+    # The model server handles code execution by appending the export path,
+    # provided to the model archiver, to the PYTHONPATH env var.
+    # code_dir has to be added to the PYTHONPATH, otherwise the
+    # user-provided module cannot be imported properly.
+    code_dir_path = "{}:".format(environment.code_dir)
+
+    if PYTHON_PATH_ENV in os.environ:
+        os.environ[PYTHON_PATH_ENV] = code_dir_path + os.environ[PYTHON_PATH_ENV]
+    else:
+        os.environ[PYTHON_PATH_ENV] = code_dir_path
+
+
+def install_requirements():
+    logger.info("installing packages from requirements.txt...")
+    pip_install_cmd = [sys.executable, "-m", "pip", "install", "-r", REQUIREMENTS_PATH]
+
+    try:
+        subprocess.check_call(pip_install_cmd)
+    except subprocess.CalledProcessError:
+        logger.error("failed to install required packages, exiting")
+        raise ValueError("failed to install required packages")
+
+
+# retry for 10 seconds
+@retry(stop_max_delay=10 * 1000)
+def retrieve_model_server_process(namespace):
+    model_server_processes = list()
+
+    for process in psutil.process_iter():
+        if namespace in process.cmdline():
+            model_server_processes.append(process)
+
+    if not model_server_processes:
+        raise Exception("model server failed to start")
+
+    if len(model_server_processes) > 1:
+        raise Exception("multiple model servers are not supported")
+
+    return model_server_processes[0]
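`retrieve_model_server_process` relies on the `retrying` package: the decorated function is simply re-invoked on every exception until `stop_max_delay` expires, after which the last exception propagates. A self-contained toy with the same decorator shape:

```python
from retrying import retry

attempts = {"n": 0}


# Same pattern as retrieve_model_server_process: keep re-invoking the
# function on exceptions until ~2 seconds have elapsed, then re-raise.
@retry(stop_max_delay=2 * 1000)
def wait_for_server():
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise Exception("server not up yet")
    return "ready"


print(wait_for_server(), "after", attempts["n"], "attempts")
```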
+"""This module contains functionality to configure and start the +multi-model server.""" +from __future__ import absolute_import + +import os +import signal +import subprocess +import sys +import importlib + +import pkg_resources +import psutil +from retrying import retry + +import sagemaker_inference +from sagemaker_inference import default_handler_service, environment, logging, utils +from sagemaker_inference.model_server_utils import add_sigterm_handler, set_python_path, install_requirements, retrieve_model_server_process +from sagemaker_inference.environment import code_dir + +logger = logging.get_logger() + +TS_CONFIG_FILE = os.path.join("/etc", "sagemaker-ts.properties") +DEFAULT_TS_CONFIG_FILE = pkg_resources.resource_filename( + sagemaker_inference.__name__, "/etc/default-ts.properties" +) +MME_TS_CONFIG_FILE = pkg_resources.resource_filename( + sagemaker_inference.__name__, "/etc/mme-ts.properties" +) +DEFAULT_TS_LOG_FILE = pkg_resources.resource_filename( + sagemaker_inference.__name__, "/etc/ts.log4j.properties" +) +DEFAULT_TS_MODEL_DIRECTORY = os.path.join(os.getcwd(), ".sagemaker/ts/models") +DEFAULT_TS_MODEL_NAME = "model" +DEFAULT_TS_MODEL_SERIALIZED_FILE = "model.pth" +DEFAULT_TS_HANDLER_SERVICE = "sagemaker_pytorch_serving_container.handler_service" + +ENABLE_MULTI_MODEL = os.getenv("SAGEMAKER_MULTI_MODEL", "false") == "true" +MODEL_STORE = "/" if ENABLE_MULTI_MODEL else DEFAULT_TS_MODEL_DIRECTORY + +PYTHON_PATH_ENV = "PYTHONPATH" +REQUIREMENTS_PATH = os.path.join(code_dir, "requirements.txt") +TS_NAMESPACE = "org.pytorch.serve.ModelServer" + + +def start_model_server(handler_service=DEFAULT_TS_HANDLER_SERVICE): + """Configure and start the model server. + + Args: + handler_service (str): python path pointing to a module that defines + a class with the following: + + - A ``handle`` method, which is invoked for all incoming inference + requests to the model server. + - A ``initialize`` method, which is invoked at model server start up + for loading the model. + + Defaults to ``sagemaker_inference.default_handler_service``. 
+ + """ + + if ENABLE_MULTI_MODEL: + if not os.getenv("SAGEMAKER_HANDLER"): + os.environ["SAGEMAKER_HANDLER"] = handler_service + set_python_path() + else: + _adapt_to_ts_format(handler_service) + + _create_torchserve_config_file() + + if os.path.exists(REQUIREMENTS_PATH): + install_requirements() + + ts_model_server_cmd = [ + "torchserve", + "--start", + "--model-store", + MODEL_STORE, + "--ts-config", + TS_CONFIG_FILE, + "--log-config", + DEFAULT_TS_LOG_FILE, + "--models", + "model.mar" + ] + + logger.info(ts_model_server_cmd) + subprocess.Popen(ts_model_server_cmd) + + ts_process = retrieve_model_server_process(TS_NAMESPACE) + + add_sigterm_handler(ts_process) + + ts_process.wait() + + +def _adapt_to_ts_format(handler_service): + if not os.path.exists(DEFAULT_TS_MODEL_DIRECTORY): + os.makedirs(DEFAULT_TS_MODEL_DIRECTORY) + + + model_archiver_cmd = [ + "torch-model-archiver", + "--model-name", + DEFAULT_TS_MODEL_NAME, + "--handler", + handler_service, + "--serialized-file", + os.path.join(environment.model_dir, DEFAULT_TS_MODEL_SERIALIZED_FILE), + "--export-path", + DEFAULT_TS_MODEL_DIRECTORY, + "--extra-files", + os.path.join(environment.model_dir, environment.Environment().module_name + ".py"), + "--version", + "1", + ] + + logger.info(model_archiver_cmd) + subprocess.check_call(model_archiver_cmd) + + set_python_path() + + +def _create_torchserve_config_file(): + configuration_properties = _generate_ts_config_properties() + + utils.write_file(TS_CONFIG_FILE, configuration_properties) + + +def _generate_ts_config_properties(): + env = environment.Environment() + + user_defined_configuration = { + "default_response_timeout": env.model_server_timeout, + "default_workers_per_model": env.model_server_workers, + "inference_address": "http://0.0.0.0:{}".format(env.inference_http_port), + "management_address": "http://0.0.0.0:{}".format(env.management_http_port), + } + + custom_configuration = str() + + for key in user_defined_configuration: + value = user_defined_configuration.get(key) + if value: + custom_configuration += "{}={}\n".format(key, value) + + if ENABLE_MULTI_MODEL: + default_configuration = utils.read_file(MME_TS_CONFIG_FILE) + else: + default_configuration = utils.read_file(DEFAULT_TS_CONFIG_FILE) + + return default_configuration + custom_configuration diff --git a/test/unit/test_model_server.py b/test/unit/test_model_server.py index 3d0e6c9..3f5f200 100644 --- a/test/unit/test_model_server.py +++ b/test/unit/test_model_server.py @@ -18,7 +18,7 @@ from mock import Mock, patch import pytest -from sagemaker_inference import environment, model_server +from sagemaker_inference import environment, model_server, model_server_utils from sagemaker_inference.model_server import MMS_NAMESPACE, REQUIREMENTS_PATH PYTHON_PATH = "python_path" @@ -27,9 +27,9 @@ @patch("subprocess.call") @patch("subprocess.Popen") -@patch("sagemaker_inference.model_server._retrieve_mms_server_process") -@patch("sagemaker_inference.model_server._add_sigterm_handler") -@patch("sagemaker_inference.model_server._install_requirements") +@patch("sagemaker_inference.model_server.retrieve_model_server_process") +@patch("sagemaker_inference.model_server.add_sigterm_handler") +@patch("sagemaker_inference.model_server.install_requirements") @patch("os.path.exists", return_value=True) @patch("sagemaker_inference.model_server._create_model_server_config_file") @patch("sagemaker_inference.model_server._adapt_to_mms_format") @@ -67,8 +67,8 @@ def test_start_model_server_default_service_handler( @patch("subprocess.call") 
@patch("subprocess.Popen") -@patch("sagemaker_inference.model_server._retrieve_mms_server_process") -@patch("sagemaker_inference.model_server._add_sigterm_handler") +@patch("sagemaker_inference.model_server.retrieve_model_server_process") +@patch("sagemaker_inference.model_server.add_sigterm_handler") @patch("sagemaker_inference.model_server._create_model_server_config_file") @patch("sagemaker_inference.model_server._adapt_to_mms_format") def test_start_model_server_custom_handler_service( @@ -81,7 +81,7 @@ def test_start_model_server_custom_handler_service( adapt.assert_called_once_with(handler_service) -@patch("sagemaker_inference.model_server._set_python_path") +@patch("sagemaker_inference.model_server.set_python_path") @patch("subprocess.check_call") @patch("os.makedirs") @patch("os.path.exists", return_value=False) @@ -111,7 +111,7 @@ def test_adapt_to_mms_format(path_exists, make_dir, subprocess_check_call, set_p set_python_path.assert_called_once_with() -@patch("sagemaker_inference.model_server._set_python_path") +@patch("sagemaker_inference.model_server.set_python_path") @patch("subprocess.check_call") @patch("os.makedirs") @patch("os.path.exists", return_value=True) @@ -126,24 +126,6 @@ def test_adapt_to_mms_format_existing_path( make_dir.assert_not_called() -@patch.dict(os.environ, {model_server.PYTHON_PATH_ENV: PYTHON_PATH}, clear=True) -def test_set_existing_python_path(): - model_server._set_python_path() - - code_dir_path = "{}:".format(environment.code_dir) - - assert os.environ[model_server.PYTHON_PATH_ENV] == code_dir_path + PYTHON_PATH - - -@patch.dict(os.environ, {}, clear=True) -def test_new_python_path(): - model_server._set_python_path() - - code_dir_path = "{}:".format(environment.code_dir) - - assert os.environ[model_server.PYTHON_PATH_ENV] == code_dir_path - - @patch("sagemaker_inference.model_server._generate_mms_config_properties") @patch("sagemaker_inference.utils.write_file") def test_create_model_server_config_file(write_file, generate_mms_config_props): @@ -192,76 +174,3 @@ def test_generate_mms_config_properties_default_workers(env, read_file): assert mms_config_properties.startswith(DEFAULT_CONFIGURATION) assert workers not in mms_config_properties - - -@patch("signal.signal") -def test_add_sigterm_handler(signal_call): - mms = Mock() - - model_server._add_sigterm_handler(mms) - - mock_calls = signal_call.mock_calls - first_argument = mock_calls[0][1][0] - second_argument = mock_calls[0][1][1] - - assert len(mock_calls) == 1 - assert first_argument == signal.SIGTERM - assert isinstance(second_argument, types.FunctionType) - - -@patch("subprocess.check_call") -def test_install_requirements(check_call): - model_server._install_requirements() - - -@patch("subprocess.check_call", side_effect=subprocess.CalledProcessError(0, "cmd")) -def test_install_requirements_installation_failed(check_call): - with pytest.raises(ValueError) as e: - model_server._install_requirements() - - assert "failed to install required packages" in str(e.value) - - -@patch("retrying.Retrying.should_reject", return_value=False) -@patch("psutil.process_iter") -def test_retrieve_mms_server_process(process_iter, retry): - server = Mock() - server.cmdline.return_value = MMS_NAMESPACE - - processes = list() - processes.append(server) - - process_iter.return_value = processes - - process = model_server._retrieve_mms_server_process() - - assert process == server - - -@patch("retrying.Retrying.should_reject", return_value=False) -@patch("psutil.process_iter", return_value=list()) -def 
-@patch("retrying.Retrying.should_reject", return_value=False)
-@patch("psutil.process_iter", return_value=list())
-def test_retrieve_mms_server_process_no_server(process_iter, retry):
-    with pytest.raises(Exception) as e:
-        model_server._retrieve_mms_server_process()
-
-    assert "mms model server was unsuccessfully started" in str(e.value)
-
-
-@patch("retrying.Retrying.should_reject", return_value=False)
-@patch("psutil.process_iter")
-def test_retrieve_mms_server_process_too_many_servers(process_iter, retry):
-    server = Mock()
-    second_server = Mock()
-    server.cmdline.return_value = MMS_NAMESPACE
-    second_server.cmdline.return_value = MMS_NAMESPACE
-
-    processes = list()
-    processes.append(server)
-    processes.append(second_server)
-
-    process_iter.return_value = processes
-
-    with pytest.raises(Exception) as e:
-        model_server._retrieve_mms_server_process()
-
-    assert "multiple mms model servers are not supported" in str(e.value)
diff --git a/test/unit/test_model_server_utils.py b/test/unit/test_model_server_utils.py
new file mode 100644
index 0000000..9cb5198
--- /dev/null
+++ b/test/unit/test_model_server_utils.py
@@ -0,0 +1,115 @@
+# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import os
+import signal
+import subprocess
+import types
+
+from mock import Mock, patch
+import pytest
+
+from sagemaker_inference import environment, model_server_utils
+from sagemaker_inference.torchserve import TS_NAMESPACE
+
+PYTHON_PATH = "python_path"
+
+
+@patch.dict(os.environ, {model_server_utils.PYTHON_PATH_ENV: PYTHON_PATH}, clear=True)
+def test_set_existing_python_path():
+    model_server_utils.set_python_path()
+
+    code_dir_path = "{}:".format(environment.code_dir)
+
+    assert os.environ[model_server_utils.PYTHON_PATH_ENV] == code_dir_path + PYTHON_PATH
+
+
+@patch.dict(os.environ, {}, clear=True)
+def test_new_python_path():
+    model_server_utils.set_python_path()
+
+    code_dir_path = "{}:".format(environment.code_dir)
+
+    assert os.environ[model_server_utils.PYTHON_PATH_ENV] == code_dir_path
+
+
+@patch("signal.signal")
+def test_add_sigterm_handler(signal_call):
+    ts = Mock()
+
+    model_server_utils.add_sigterm_handler(ts)
+
+    mock_calls = signal_call.mock_calls
+    first_argument = mock_calls[0][1][0]
+    second_argument = mock_calls[0][1][1]
+
+    assert len(mock_calls) == 1
+    assert first_argument == signal.SIGTERM
+    assert isinstance(second_argument, types.FunctionType)
+
+
+@patch("subprocess.check_call")
+def test_install_requirements(check_call):
+    model_server_utils.install_requirements()
+
+
+@patch("subprocess.check_call", side_effect=subprocess.CalledProcessError(0, "cmd"))
+def test_install_requirements_installation_failed(check_call):
+    with pytest.raises(ValueError) as e:
+        model_server_utils.install_requirements()
+
+    assert "failed to install required packages" in str(e.value)
+
+
+@patch("retrying.Retrying.should_reject", return_value=False)
+@patch("psutil.process_iter")
+def test_retrieve_model_server_process(process_iter, retry):
+    server = Mock()
+    server.cmdline.return_value = TS_NAMESPACE
+
+    processes = list()
+    processes.append(server)
+
+    process_iter.return_value = processes
+
+    process = model_server_utils.retrieve_model_server_process(TS_NAMESPACE)
+
+    assert process == server
+
+
+@patch("retrying.Retrying.should_reject", return_value=False)
+@patch("psutil.process_iter", return_value=list())
+def test_retrieve_model_server_process_no_server(process_iter, retry):
+    with pytest.raises(Exception) as e:
+        model_server_utils.retrieve_model_server_process(TS_NAMESPACE)
+
+    assert "model server failed to start" in str(e.value)
+
+
+@patch("retrying.Retrying.should_reject", return_value=False)
+@patch("psutil.process_iter")
+def test_retrieve_model_server_process_too_many_servers(process_iter, retry):
+    server = Mock()
+    second_server = Mock()
+    server.cmdline.return_value = TS_NAMESPACE
+    second_server.cmdline.return_value = TS_NAMESPACE
+
+    processes = list()
+    processes.append(server)
+    processes.append(second_server)
+
+    process_iter.return_value = processes
+
+    with pytest.raises(Exception) as e:
+        model_server_utils.retrieve_model_server_process(TS_NAMESPACE)
+
+    assert "multiple model servers are not supported" in str(e.value)
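These tests all stub `retrying.Retrying.should_reject`. With it forced to `False`, the first attempt is accepted and its exception propagates immediately, so the failure cases do not spin in the 10-second retry loop. A minimal demonstration of the same trick:

```python
from mock import patch
from retrying import retry


@retry(stop_max_delay=10 * 1000)
def always_fails():
    raise Exception("server never came up")


# With should_reject forced to False, the first attempt is accepted and
# its exception re-raised at once instead of being retried for 10 seconds.
with patch("retrying.Retrying.should_reject", return_value=False):
    try:
        always_fails()
    except Exception as e:
        print("raised without retrying:", e)
```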
diff --git a/test/unit/test_torchserve.py b/test/unit/test_torchserve.py
new file mode 100644
index 0000000..1ed3010
--- /dev/null
+++ b/test/unit/test_torchserve.py
@@ -0,0 +1,176 @@
+# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import os
+
+from mock import Mock, patch
+
+from sagemaker_inference import environment, torchserve
+from sagemaker_inference.torchserve import REQUIREMENTS_PATH, TS_NAMESPACE
+
+DEFAULT_CONFIGURATION = "default_configuration"
+
+
+@patch("subprocess.call")
+@patch("subprocess.Popen")
+@patch("sagemaker_inference.torchserve.retrieve_model_server_process")
+@patch("sagemaker_inference.torchserve.add_sigterm_handler")
+@patch("sagemaker_inference.torchserve.install_requirements")
+@patch("os.path.exists", return_value=True)
+@patch("sagemaker_inference.torchserve._create_torchserve_config_file")
+@patch("sagemaker_inference.torchserve._adapt_to_ts_format")
+def test_start_torchserve_default_service_handler(
+    adapt,
+    create_config,
+    exists,
+    install_requirements,
+    sigterm,
+    retrieve,
+    subprocess_popen,
+    subprocess_call,
+):
+    torchserve.start_model_server()
+
+    adapt.assert_called_once_with(torchserve.DEFAULT_TS_HANDLER_SERVICE)
+    create_config.assert_called_once_with()
+    exists.assert_called_once_with(REQUIREMENTS_PATH)
+    install_requirements.assert_called_once_with()
+
+    ts_model_server_cmd = [
+        "torchserve",
+        "--start",
+        "--model-store",
+        torchserve.MODEL_STORE,
+        "--ts-config",
+        torchserve.TS_CONFIG_FILE,
+        "--log-config",
+        torchserve.DEFAULT_TS_LOG_FILE,
+        "--models",
+        "model.mar",
+    ]
+
+    subprocess_popen.assert_called_once_with(ts_model_server_cmd)
+    retrieve.assert_called_once_with(TS_NAMESPACE)
+    sigterm.assert_called_once_with(retrieve.return_value)
+
+
+@patch("subprocess.call")
+@patch("subprocess.Popen")
+@patch("sagemaker_inference.torchserve.retrieve_model_server_process")
+@patch("sagemaker_inference.torchserve.add_sigterm_handler")
+@patch("sagemaker_inference.torchserve._create_torchserve_config_file")
+@patch("sagemaker_inference.torchserve._adapt_to_ts_format")
+def test_start_torchserve_custom_handler_service(
+    adapt, create_config, sigterm, retrieve, subprocess_popen, subprocess_call
+):
+    handler_service = Mock()
+
+    torchserve.start_model_server(handler_service)
+
+    adapt.assert_called_once_with(handler_service)
+
+
+@patch("sagemaker_inference.torchserve.set_python_path")
+@patch("subprocess.check_call")
+@patch("os.makedirs")
+@patch("os.path.exists", return_value=False)
+def test_adapt_to_ts_format(path_exists, make_dir, subprocess_check_call, set_python_path):
+    handler_service = Mock()
+
+    torchserve._adapt_to_ts_format(handler_service)
+
+    path_exists.assert_called_once_with(torchserve.DEFAULT_TS_MODEL_DIRECTORY)
+    make_dir.assert_called_once_with(torchserve.DEFAULT_TS_MODEL_DIRECTORY)
+
+    model_archiver_cmd = [
+        "torch-model-archiver",
+        "--model-name",
+        torchserve.DEFAULT_TS_MODEL_NAME,
+        "--handler",
+        handler_service,
+        "--serialized-file",
+        os.path.join(environment.model_dir, torchserve.DEFAULT_TS_MODEL_SERIALIZED_FILE),
+        "--export-path",
+        torchserve.DEFAULT_TS_MODEL_DIRECTORY,
+        "--extra-files",
+        os.path.join(environment.model_dir, environment.Environment().module_name + ".py"),
+        "--version",
+        "1",
+    ]
+
+    subprocess_check_call.assert_called_once_with(model_archiver_cmd)
+    set_python_path.assert_called_once_with()
+
+
+@patch("sagemaker_inference.torchserve.set_python_path")
+@patch("subprocess.check_call")
+@patch("os.makedirs")
+@patch("os.path.exists", return_value=True)
+def test_adapt_to_ts_format_existing_path(
+    path_exists, make_dir, subprocess_check_call, set_python_path
+):
+    handler_service = Mock()
+
+    torchserve._adapt_to_ts_format(handler_service)
+
+    path_exists.assert_called_once_with(torchserve.DEFAULT_TS_MODEL_DIRECTORY)
+    make_dir.assert_not_called()
+
+
+@patch("sagemaker_inference.torchserve._generate_ts_config_properties")
+@patch("sagemaker_inference.utils.write_file")
+def test_create_torchserve_config_file(write_file, generate_ts_config_props):
+    torchserve._create_torchserve_config_file()
+
+    write_file.assert_called_once_with(
+        torchserve.TS_CONFIG_FILE, generate_ts_config_props.return_value
+    )
+
+
+@patch("sagemaker_inference.utils.read_file", return_value=DEFAULT_CONFIGURATION)
+@patch("sagemaker_inference.environment.Environment")
+def test_generate_ts_config_properties(env, read_file):
+    torchserve_timeout = "model_server_timeout"
+    torchserve_workers = "model_server_workers"
+    http_port = "http_port"
+
+    env.return_value.model_server_timeout = torchserve_timeout
+    env.return_value.model_server_workers = torchserve_workers
+    env.return_value.inference_http_port = http_port
+
+    ts_config_properties = torchserve._generate_ts_config_properties()
+
+    inference_address = "inference_address=http://0.0.0.0:{}\n".format(http_port)
+    server_timeout = "default_response_timeout={}\n".format(torchserve_timeout)
+    workers = "default_workers_per_model={}\n".format(torchserve_workers)
+
+    read_file.assert_called_once_with(torchserve.DEFAULT_TS_CONFIG_FILE)
+
+    assert ts_config_properties.startswith(DEFAULT_CONFIGURATION)
+    assert inference_address in ts_config_properties
+    assert server_timeout in ts_config_properties
+    assert workers in ts_config_properties
+
+
+@patch("sagemaker_inference.utils.read_file", return_value=DEFAULT_CONFIGURATION)
+@patch("sagemaker_inference.environment.Environment")
+def test_generate_ts_config_properties_default_workers(env, read_file):
+    env.return_value.model_server_workers = None
+
+    ts_config_properties = torchserve._generate_ts_config_properties()
+
+    workers = "default_workers_per_model={}".format(None)
+
+    read_file.assert_called_once_with(torchserve.DEFAULT_TS_CONFIG_FILE)
+
+    assert ts_config_properties.startswith(DEFAULT_CONFIGURATION)
+    assert workers not in ts_config_properties
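For reference, a hypothetical expansion of the archiver invocation that `test_adapt_to_ts_format` exercises, assuming the usual SageMaker layout (`model_dir=/opt/ml/model`) and a user module named `inference.py`; both are assumptions for illustration only:

```python
# What _adapt_to_ts_format effectively runs under the assumed layout;
# the export path is relative to the working directory, matching
# DEFAULT_TS_MODEL_DIRECTORY above.
import subprocess

subprocess.check_call([
    "torch-model-archiver",
    "--model-name", "model",
    "--handler", "sagemaker_pytorch_serving_container.handler_service",
    "--serialized-file", "/opt/ml/model/model.pth",
    "--export-path", ".sagemaker/ts/models",
    "--extra-files", "/opt/ml/model/inference.py",  # module name is an assumption
    "--version", "1",
])
```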