From 6085b5bee3cd22c641f3c86181e2e5784389d4ce Mon Sep 17 00:00:00 2001 From: philschmid Date: Thu, 10 Feb 2022 09:20:39 +0100 Subject: [PATCH 1/2] fixes configurable startup timeout --- setup.py | 2 +- .../mms_model_server.py | 11 +++++--- tests/unit/test_mms_model_server.py | 25 +++++++++++++------ 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/setup.py b/setup.py index b53d482..102378e 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,7 @@ VERSION = "1.3.0" install_requires = [ - "sagemaker-inference>=1.5.5", + "sagemaker-inference>=1.5.11", "huggingface_hub>=0.0.8", "retrying", "numpy", diff --git a/src/sagemaker_huggingface_inference_toolkit/mms_model_server.py b/src/sagemaker_huggingface_inference_toolkit/mms_model_server.py index bb5482d..b60f376 100644 --- a/src/sagemaker_huggingface_inference_toolkit/mms_model_server.py +++ b/src/sagemaker_huggingface_inference_toolkit/mms_model_server.py @@ -17,7 +17,7 @@ import pathlib import subprocess -from sagemaker_inference import logging +from sagemaker_inference import logging, environment from sagemaker_inference.environment import model_dir from sagemaker_inference.model_server import ( DEFAULT_MMS_LOG_FILE, @@ -28,7 +28,7 @@ _add_sigterm_handler, _create_model_server_config_file, _install_requirements, - _retrieve_mms_server_process, + _retry_retrieve_mms_server_process, _set_python_path, ) @@ -84,7 +84,8 @@ def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE): else: _adapt_to_mms_format(handler_service, model_dir) - _create_model_server_config_file() + env = environment.Environment() + _create_model_server_config_file(env) if os.path.exists(REQUIREMENTS_PATH): _install_requirements() @@ -102,7 +103,9 @@ def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE): logger.info(multi_model_server_cmd) subprocess.Popen(multi_model_server_cmd) - mms_process = _retrieve_mms_server_process() + # retry for configured timeout + mms_process = _retry_retrieve_mms_server_process(env.startup_timeout) + _add_sigterm_handler(mms_process) _add_sigchild_handler() diff --git a/tests/unit/test_mms_model_server.py b/tests/unit/test_mms_model_server.py index 9da7cb7..4221242 100644 --- a/tests/unit/test_mms_model_server.py +++ b/tests/unit/test_mms_model_server.py @@ -26,13 +26,15 @@ @patch("subprocess.call") @patch("subprocess.Popen") -@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retrieve_mms_server_process") +@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retry_retrieve_mms_server_process") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._add_sigterm_handler") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._install_requirements") @patch("os.path.exists", return_value=True) @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._create_model_server_config_file") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._adapt_to_mms_format") +@patch("sagemaker_inference.environment.Environment") def test_start_mms_default_service_handler( + env, adapt, create_config, exists, @@ -42,10 +44,11 @@ def test_start_mms_default_service_handler( subprocess_popen, subprocess_call, ): + env.return_value.startup_timeout = 10000 mms_model_server.start_model_server() adapt.assert_called_once_with(mms_model_server.DEFAULT_HANDLER_SERVICE, model_dir) - create_config.assert_called_once_with() + create_config.assert_called_once_with(env.return_value) exists.assert_called_once_with(mms_model_server.REQUIREMENTS_PATH) install_requirements.assert_called_once_with() @@ -67,7 +70,7 @@ def test_start_mms_default_service_handler( @patch("sagemaker_huggingface_inference_toolkit.transformers_utils._aws_neuron_available", return_value=True) @patch("subprocess.call") @patch("subprocess.Popen") -@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retrieve_mms_server_process") +@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retry_retrieve_mms_server_process") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._load_model_from_hub") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._add_sigterm_handler") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._install_requirements") @@ -76,7 +79,9 @@ def test_start_mms_default_service_handler( @patch("os.path.exists", return_value=True) @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._create_model_server_config_file") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._adapt_to_mms_format") +@patch("sagemaker_inference.environment.Environment") def test_start_mms_neuron( + env, adapt, create_config, exists, @@ -90,11 +95,11 @@ def test_start_mms_neuron( subprocess_call, is_aws_neuron_available, ): - + env.return_value.startup_timeout = 10000 mms_model_server.start_model_server() adapt.assert_called_once_with(mms_model_server.DEFAULT_HANDLER_SERVICE, model_dir) - create_config.assert_called_once_with() + create_config.assert_called_once_with(env.return_value) exists.assert_called_once_with(mms_model_server.REQUIREMENTS_PATH) install_requirements.assert_called_once_with() @@ -115,7 +120,7 @@ def test_start_mms_neuron( @patch("subprocess.call") @patch("subprocess.Popen") -@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retrieve_mms_server_process") +@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retry_retrieve_mms_server_process") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._load_model_from_hub") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._add_sigterm_handler") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._install_requirements") @@ -124,7 +129,9 @@ def test_start_mms_neuron( @patch("os.path.exists", return_value=True) @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._create_model_server_config_file") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._adapt_to_mms_format") +@patch("sagemaker_inference.environment.Environment") def test_start_mms_with_model_from_hub( + env, adapt, create_config, exists, @@ -137,6 +144,8 @@ def test_start_mms_with_model_from_hub( subprocess_popen, subprocess_call, ): + env.return_value.startup_timeout = 10000 + os.environ["HF_MODEL_ID"] = "lysandre/tiny-bert-random" mms_model_server.start_model_server() @@ -149,7 +158,7 @@ def test_start_mms_with_model_from_hub( ) adapt.assert_called_once_with(mms_model_server.DEFAULT_HANDLER_SERVICE, load_model_from_hub()) - create_config.assert_called_once_with() + create_config.assert_called_once_with(env.return_value) exists.assert_called_with(mms_model_server.REQUIREMENTS_PATH) install_requirements.assert_called_once_with() @@ -172,7 +181,7 @@ def test_start_mms_with_model_from_hub( @patch("sagemaker_huggingface_inference_toolkit.transformers_utils._aws_neuron_available", return_value=True) @patch("subprocess.call") @patch("subprocess.Popen") -@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retrieve_mms_server_process") +@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retry_retrieve_mms_server_process") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._load_model_from_hub") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._add_sigterm_handler") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._install_requirements") From 7f38612571f44e0baf4f8f851dcb155e0fb967ef Mon Sep 17 00:00:00 2001 From: philschmid Date: Thu, 10 Feb 2022 09:27:59 +0100 Subject: [PATCH 2/2] changed version --- setup.py | 2 +- src/sagemaker_huggingface_inference_toolkit/mms_model_server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 102378e..477bec1 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ # We don't declare our dependency on transformers here because we build with # different packages for different variants -VERSION = "1.3.0" +VERSION = "1.3.1" install_requires = [ "sagemaker-inference>=1.5.11", diff --git a/src/sagemaker_huggingface_inference_toolkit/mms_model_server.py b/src/sagemaker_huggingface_inference_toolkit/mms_model_server.py index b60f376..d949617 100644 --- a/src/sagemaker_huggingface_inference_toolkit/mms_model_server.py +++ b/src/sagemaker_huggingface_inference_toolkit/mms_model_server.py @@ -17,7 +17,7 @@ import pathlib import subprocess -from sagemaker_inference import logging, environment +from sagemaker_inference import environment, logging from sagemaker_inference.environment import model_dir from sagemaker_inference.model_server import ( DEFAULT_MMS_LOG_FILE,