diff --git a/src/sagemaker_pytorch_serving_container/serving.py b/src/sagemaker_pytorch_serving_container/serving.py index 5e70c961..608fbf5d 100644 --- a/src/sagemaker_pytorch_serving_container/serving.py +++ b/src/sagemaker_pytorch_serving_container/serving.py @@ -12,6 +12,8 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import os,sys +import subprocess from subprocess import CalledProcessError from retrying import retry @@ -20,11 +22,49 @@ HANDLER_SERVICE = handler_service.__file__ +## added logging function to configure log4j2 loglevel. +def configure_logging(): + log_levels = { + '0': 'off', + '10': 'fatal', + '20': 'error', + '30': 'warn', + '40': 'info', + '50': 'debug', + '60': 'trace' + } + + # Get the directory of the current script + current_script_path = os.path.abspath(__file__) + + # Construct the path to log4j2.xml relative to the script location + log4j2_path = os.path.join(os.path.dirname(current_script_path), 'etc', 'log4j2.xml') + + print(f"Current script path: {current_script_path}") + print(f"log4j2.xml path: {log4j2_path}") + + if not os.path.exists(log4j2_path): + print(f"Error: {log4j2_path} does not exist", file=sys.stderr) + return + + ts_log_level = os.environ.get('TS_LOG_LEVEL') + + if ts_log_level is not None: + if ts_log_level in log_levels: + try: + log_level = log_levels[ts_log_level] + subprocess.run(['sed', '-i', f's/info/{log_level}/g', log4j2_path], check=True) + print(f"Logging level set to {log_level}") + except subprocess.CalledProcessError as e: + print(f"Error configuring the logging: {e}", file=sys.stderr) + else: + print(f"Invalid TS_LOG_LEVEL value: {ts_log_level}. No changes made to logging configuration.", file=sys.stderr) + else: + print("TS_LOG_LEVEL not set. Using default logging configuration.") def _retry_if_error(exception): return isinstance(exception, CalledProcessError) - @retry(stop_max_delay=1000 * 30, retry_on_exception=_retry_if_error) def _start_torchserve(): @@ -33,6 +73,9 @@ def _start_torchserve(): # retry starting mms until it's ready torchserve.start_torchserve(handler_service=HANDLER_SERVICE) - def main(): + configure_logging() _start_torchserve() + +if __name__ == '__main__': + main() diff --git a/test/unit/test_log_config.py b/test/unit/test_log_config.py new file mode 100644 index 00000000..096b7f1e --- /dev/null +++ b/test/unit/test_log_config.py @@ -0,0 +1,68 @@ +import unittest +from unittest.mock import patch, MagicMock +import os, subprocess +import sys +import io + +# Add the src directory to the Python path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'src'))) + +from sagemaker_pytorch_serving_container.serving import configure_logging + +class TestLogConfig(unittest.TestCase): + @patch('os.path.exists') + @patch('os.environ.get') + @patch('subprocess.run') + def test_valid_log_level(self, mock_run, mock_env_get, mock_exists): + mock_exists.return_value = True + mock_env_get.return_value = '20' + mock_run.return_value = MagicMock(returncode=0) + + with patch('sys.stdout', new=io.StringIO()) as fake_out: + configure_logging() + self.assertIn("Logging level set to error", fake_out.getvalue()) + + mock_run.assert_called_once() + + @patch('os.path.exists') + @patch('os.environ.get') + def test_invalid_log_level(self, mock_env_get, mock_exists): + mock_exists.return_value = True + mock_env_get.return_value = '70' + + with patch('sys.stderr', new=io.StringIO()) as fake_err: + configure_logging() + self.assertIn("Invalid TS_LOG_LEVEL value: 70", fake_err.getvalue()) + + @patch('os.path.exists') + @patch('os.environ.get') + def test_no_log_level_set(self, mock_env_get, mock_exists): + mock_exists.return_value = True + mock_env_get.return_value = None + + with patch('sys.stdout', new=io.StringIO()) as fake_out: + configure_logging() + self.assertIn("TS_LOG_LEVEL not set", fake_out.getvalue()) + + @patch('os.path.exists') + @patch('os.environ.get') + @patch('subprocess.run') + def test_subprocess_error(self, mock_run, mock_env_get, mock_exists): + mock_exists.return_value = True + mock_env_get.return_value = '20' + mock_run.side_effect = subprocess.CalledProcessError(1, 'sed') + + with patch('sys.stderr', new=io.StringIO()) as fake_err: + configure_logging() + self.assertIn("Error configuring the logging", fake_err.getvalue()) + + @patch('os.path.exists') + def test_log4j2_file_not_found(self, mock_exists): + mock_exists.return_value = False + + with patch('sys.stderr', new=io.StringIO()) as fake_err: + configure_logging() + self.assertIn("does not exist", fake_err.getvalue()) + +if __name__ == '__main__': + unittest.main()