diff --git a/setup.py b/setup.py index b53d482..477bec1 100644 --- a/setup.py +++ b/setup.py @@ -30,10 +30,10 @@ # We don't declare our dependency on transformers here because we build with # different packages for different variants -VERSION = "1.3.0" +VERSION = "1.3.1" install_requires = [ - "sagemaker-inference>=1.5.5", + "sagemaker-inference>=1.5.11", "huggingface_hub>=0.0.8", "retrying", "numpy", diff --git a/src/sagemaker_huggingface_inference_toolkit/mms_model_server.py b/src/sagemaker_huggingface_inference_toolkit/mms_model_server.py index bb5482d..d949617 100644 --- a/src/sagemaker_huggingface_inference_toolkit/mms_model_server.py +++ b/src/sagemaker_huggingface_inference_toolkit/mms_model_server.py @@ -17,7 +17,7 @@ import pathlib import subprocess -from sagemaker_inference import logging +from sagemaker_inference import environment, logging from sagemaker_inference.environment import model_dir from sagemaker_inference.model_server import ( DEFAULT_MMS_LOG_FILE, @@ -28,7 +28,7 @@ _add_sigterm_handler, _create_model_server_config_file, _install_requirements, - _retrieve_mms_server_process, + _retry_retrieve_mms_server_process, _set_python_path, ) @@ -84,7 +84,8 @@ def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE): else: _adapt_to_mms_format(handler_service, model_dir) - _create_model_server_config_file() + env = environment.Environment() + _create_model_server_config_file(env) if os.path.exists(REQUIREMENTS_PATH): _install_requirements() @@ -102,7 +103,9 @@ def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE): logger.info(multi_model_server_cmd) subprocess.Popen(multi_model_server_cmd) - mms_process = _retrieve_mms_server_process() + # retry for configured timeout + mms_process = _retry_retrieve_mms_server_process(env.startup_timeout) + _add_sigterm_handler(mms_process) _add_sigchild_handler() diff --git a/tests/unit/test_mms_model_server.py b/tests/unit/test_mms_model_server.py index 9da7cb7..4221242 100644 --- a/tests/unit/test_mms_model_server.py +++ b/tests/unit/test_mms_model_server.py @@ -26,13 +26,15 @@ @patch("subprocess.call") @patch("subprocess.Popen") -@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retrieve_mms_server_process") +@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retry_retrieve_mms_server_process") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._add_sigterm_handler") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._install_requirements") @patch("os.path.exists", return_value=True) @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._create_model_server_config_file") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._adapt_to_mms_format") +@patch("sagemaker_inference.environment.Environment") def test_start_mms_default_service_handler( + env, adapt, create_config, exists, @@ -42,10 +44,11 @@ def test_start_mms_default_service_handler( subprocess_popen, subprocess_call, ): + env.return_value.startup_timeout = 10000 mms_model_server.start_model_server() adapt.assert_called_once_with(mms_model_server.DEFAULT_HANDLER_SERVICE, model_dir) - create_config.assert_called_once_with() + create_config.assert_called_once_with(env.return_value) exists.assert_called_once_with(mms_model_server.REQUIREMENTS_PATH) install_requirements.assert_called_once_with() @@ -67,7 +70,7 @@ def test_start_mms_default_service_handler( @patch("sagemaker_huggingface_inference_toolkit.transformers_utils._aws_neuron_available", return_value=True) @patch("subprocess.call") @patch("subprocess.Popen") -@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retrieve_mms_server_process") +@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retry_retrieve_mms_server_process") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._load_model_from_hub") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._add_sigterm_handler") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._install_requirements") @@ -76,7 +79,9 @@ def test_start_mms_default_service_handler( @patch("os.path.exists", return_value=True) @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._create_model_server_config_file") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._adapt_to_mms_format") +@patch("sagemaker_inference.environment.Environment") def test_start_mms_neuron( + env, adapt, create_config, exists, @@ -90,11 +95,11 @@ def test_start_mms_neuron( subprocess_call, is_aws_neuron_available, ): - + env.return_value.startup_timeout = 10000 mms_model_server.start_model_server() adapt.assert_called_once_with(mms_model_server.DEFAULT_HANDLER_SERVICE, model_dir) - create_config.assert_called_once_with() + create_config.assert_called_once_with(env.return_value) exists.assert_called_once_with(mms_model_server.REQUIREMENTS_PATH) install_requirements.assert_called_once_with() @@ -115,7 +120,7 @@ def test_start_mms_neuron( @patch("subprocess.call") @patch("subprocess.Popen") -@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retrieve_mms_server_process") +@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retry_retrieve_mms_server_process") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._load_model_from_hub") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._add_sigterm_handler") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._install_requirements") @@ -124,7 +129,9 @@ def test_start_mms_neuron( @patch("os.path.exists", return_value=True) @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._create_model_server_config_file") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._adapt_to_mms_format") +@patch("sagemaker_inference.environment.Environment") def test_start_mms_with_model_from_hub( + env, adapt, create_config, exists, @@ -137,6 +144,8 @@ def test_start_mms_with_model_from_hub( subprocess_popen, subprocess_call, ): + env.return_value.startup_timeout = 10000 + os.environ["HF_MODEL_ID"] = "lysandre/tiny-bert-random" mms_model_server.start_model_server() @@ -149,7 +158,7 @@ def test_start_mms_with_model_from_hub( ) adapt.assert_called_once_with(mms_model_server.DEFAULT_HANDLER_SERVICE, load_model_from_hub()) - create_config.assert_called_once_with() + create_config.assert_called_once_with(env.return_value) exists.assert_called_with(mms_model_server.REQUIREMENTS_PATH) install_requirements.assert_called_once_with() @@ -172,7 +181,7 @@ def test_start_mms_with_model_from_hub( @patch("sagemaker_huggingface_inference_toolkit.transformers_utils._aws_neuron_available", return_value=True) @patch("subprocess.call") @patch("subprocess.Popen") -@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retrieve_mms_server_process") +@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retry_retrieve_mms_server_process") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._load_model_from_hub") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._add_sigterm_handler") @patch("sagemaker_huggingface_inference_toolkit.mms_model_server._install_requirements")