diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index 2c087adf81..b7dc5b1da5 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -95,6 +95,7 @@ def _is_jumpstart_model_id(self) -> bool: def _create_pre_trained_js_model(self) -> Type[Model]: """Placeholder docstring""" pysdk_model = JumpStartModel(self.model) + pysdk_model.sagemaker_session = self.sagemaker_session self._original_deploy = pysdk_model.deploy pysdk_model.deploy = self._js_builder_deploy_wrapper diff --git a/src/sagemaker/serve/builder/tgi_builder.py b/src/sagemaker/serve/builder/tgi_builder.py index ef25e6ff93..832e3a9258 100644 --- a/src/sagemaker/serve/builder/tgi_builder.py +++ b/src/sagemaker/serve/builder/tgi_builder.py @@ -133,7 +133,10 @@ def _create_tgi_model(self) -> Type[Model]: logger.info("Auto detected %s. Proceeding with the the deployment.", self.image_uri) pysdk_model = HuggingFaceModel( - image_uri=self.image_uri, env=self.env_vars, role=self.role_arn + image_uri=self.image_uri, + env=self.env_vars, + role=self.role_arn, + sagemaker_session=self.sagemaker_session, ) self._original_deploy = pysdk_model.deploy diff --git a/src/sagemaker/serve/model_server/djl_serving/prepare.py b/src/sagemaker/serve/model_server/djl_serving/prepare.py index cda21c93c3..386c5fb66e 100644 --- a/src/sagemaker/serve/model_server/djl_serving/prepare.py +++ b/src/sagemaker/serve/model_server/djl_serving/prepare.py @@ -14,14 +14,14 @@ from __future__ import absolute_import import shutil -import tarfile -import subprocess import json +import tarfile import logging from typing import List from pathlib import Path from sagemaker.utils import _tmpdir +from sagemaker.s3 import S3Downloader from sagemaker.djl_inference import DJLModel from sagemaker.djl_inference.model import _read_existing_serving_properties from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage @@ -34,27 +34,57 @@ def _has_serving_properties_file(code_dir: Path) -> bool: - """Placeholder Docstring""" + """Check for existing serving properties in the directory""" return code_dir.joinpath(_SERVING_PROPERTIES_FILE).is_file() -def _members(resources: object, depth: int): - """Placeholder Docstring""" - for member in resources.getmembers(): - member.path = member.path.split("/", depth)[-1] - yield member +def _move_to_code_dir(js_model_dir: str, code_dir: Path): + """Move DJL Jumpstart resources from model to code_dir""" + js_model_resources = Path(js_model_dir).joinpath("model") + for resource in js_model_resources.glob("*"): + try: + shutil.move(resource, code_dir) + except shutil.Error as e: + if "already exists" in str(e): + continue + + +def _extract_js_resource(js_model_dir: str, js_id: str): + """Uncompress the jumpstart resource""" + tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz") + with tarfile.open(str(tmp_sourcedir)) as resources: + resources.extractall(path=js_model_dir) def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path): - """Placeholder Docstring""" + """Copy the associated JumpStart Resource into the code directory""" logger.info("Downloading JumpStart artifacts from S3...") - with _tmpdir(directory=str(code_dir)) as js_model_dir: - subprocess.run(["aws", "s3", "cp", model_data, js_model_dir]) - logger.info("Uncompressing JumpStart artifacts for faster loading...") - tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz") - with tarfile.open(str(tmp_sourcedir)) as resources: - resources.extractall(path=code_dir, members=_members(resources, 1)) + s3_downloader = S3Downloader() + invalid_model_data_format = False + with _tmpdir(directory=str(code_dir)) as js_model_dir: + if isinstance(model_data, str): + if model_data.endswith(".tar.gz"): + logger.info("Uncompressing JumpStart artifacts for faster loading...") + s3_downloader.download(model_data, js_model_dir) + _extract_js_resource(js_model_dir, js_id) + else: + logger.info("Copying uncompressed JumpStart artifacts...") + s3_downloader.download(model_data, js_model_dir) + elif ( + isinstance(model_data, dict) + and model_data.get("S3DataSource") + and model_data.get("S3DataSource").get("S3Uri") + ): + logger.info("Copying uncompressed JumpStart artifacts...") + s3_downloader.download(model_data.get("S3DataSource").get("S3Uri"), js_model_dir) + else: + invalid_model_data_format = True + if not invalid_model_data_format: + _move_to_code_dir(js_model_dir, code_dir) + + if invalid_model_data_format: + raise ValueError("JumpStart model data compression format is unsupported: %s", model_data) existing_properties = _read_existing_serving_properties(code_dir) config_json_file = code_dir.joinpath("config.json") @@ -70,7 +100,7 @@ def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path): def _generate_properties_file( model: DJLModel, code_dir: Path, overwrite_props_from_file: bool, manual_set_props: dict ): - """Placeholder Docstring""" + """Construct serving properties file taking into account of overrides or manual specs""" if _has_serving_properties_file(code_dir): existing_properties = _read_existing_serving_properties(code_dir) else: diff --git a/src/sagemaker/serve/model_server/tgi/prepare.py b/src/sagemaker/serve/model_server/tgi/prepare.py index 6159841ff3..fe1162e505 100644 --- a/src/sagemaker/serve/model_server/tgi/prepare.py +++ b/src/sagemaker/serve/model_server/tgi/prepare.py @@ -1,37 +1,66 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. """Prepare TgiModel for Deployment""" from __future__ import absolute_import import tarfile -import subprocess import logging from typing import List from pathlib import Path from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage from sagemaker.utils import _tmpdir +from sagemaker.s3 import S3Downloader logger = logging.getLogger(__name__) +def _extract_js_resource(js_model_dir: str, code_dir: Path, js_id: str): + """Uncompress the jumpstart resource""" + tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz") + with tarfile.open(str(tmp_sourcedir)) as resources: + resources.extractall(path=code_dir) + + def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> bool: - """Placeholder Docstring""" + """Copy the associated JumpStart Resource into the code directory""" logger.info("Downloading JumpStart artifacts from S3...") - with _tmpdir(directory=str(code_dir)) as js_model_dir: - js_model_data_loc = model_data.get("S3DataSource").get("S3Uri") - # TODO: leave this check here until we are sure every js model has moved to uncompressed - if js_model_data_loc.endswith("tar.gz"): - subprocess.run(["aws", "s3", "cp", js_model_data_loc, js_model_dir]) + + s3_downloader = S3Downloader() + if isinstance(model_data, str): + if model_data.endswith(".tar.gz"): logger.info("Uncompressing JumpStart artifacts for faster loading...") - tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz") - with tarfile.open(str(tmp_sourcedir)) as resources: - resources.extractall(path=code_dir) + with _tmpdir(directory=str(code_dir)) as js_model_dir: + s3_downloader.download(model_data, js_model_dir) + _extract_js_resource(js_model_dir, code_dir, js_id) else: - subprocess.run(["aws", "s3", "cp", js_model_data_loc, js_model_dir, "--recursive"]) + logger.info("Copying uncompressed JumpStart artifacts...") + s3_downloader.download(model_data, code_dir) + elif ( + isinstance(model_data, dict) + and model_data.get("S3DataSource") + and model_data.get("S3DataSource").get("S3Uri") + ): + logger.info("Copying uncompressed JumpStart artifacts...") + s3_downloader.download(model_data.get("S3DataSource").get("S3Uri"), code_dir) + else: + raise ValueError("JumpStart model data compression format is unsupported: %s", model_data) + return True def _create_dir_structure(model_path: str) -> tuple: - """Placeholder Docstring""" + """Create the expected model directory structure for the TGI server""" model_path = Path(model_path) if not model_path.exists(): model_path.mkdir(parents=True) diff --git a/tests/unit/sagemaker/serve/builder/test_djl_builder.py b/tests/unit/sagemaker/serve/builder/test_djl_builder.py index 018dc3356e..ee52373dd7 100644 --- a/tests/unit/sagemaker/serve/builder/test_djl_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_djl_builder.py @@ -114,10 +114,12 @@ def test_build_deploy_for_djl_local_container( mode=Mode.LOCAL_CONTAINER, model_server=ModelServer.DJL_SERVING, ) + builder._prepare_for_mode = MagicMock() builder._prepare_for_mode.side_effect = None model = builder.build() + builder.serve_settings.telemetry_opt_out = True assert isinstance(model, HuggingFaceAccelerateModel) assert ( @@ -176,6 +178,7 @@ def test_build_for_djl_local_container_faster_transformer( model_server=ModelServer.DJL_SERVING, ) model = builder.build() + builder.serve_settings.telemetry_opt_out = True assert isinstance(model, FasterTransformerModel) assert ( @@ -211,6 +214,7 @@ def test_build_for_djl_local_container_deepspeed( model_server=ModelServer.DJL_SERVING, ) model = builder.build() + builder.serve_settings.telemetry_opt_out = True assert isinstance(model, DeepSpeedModel) assert model.generate_serving_properties() == mock_expected_deepspeed_serving_properties @@ -268,6 +272,7 @@ def test_tune_for_djl_local_container( builder._djl_model_builder_deploy_wrapper = MagicMock() model = builder.build() + builder.serve_settings.telemetry_opt_out = True tuned_model = model.tune() assert tuned_model.generate_serving_properties() == mock_most_performant_serving_properties @@ -317,6 +322,7 @@ def test_tune_for_djl_local_container_deep_ping_ex( builder._prepare_for_mode.side_effect = None model = builder.build() + builder.serve_settings.telemetry_opt_out = True tuned_model = model.tune() assert ( tuned_model.generate_serving_properties() @@ -369,6 +375,7 @@ def test_tune_for_djl_local_container_load_ex( builder._prepare_for_mode.side_effect = None model = builder.build() + builder.serve_settings.telemetry_opt_out = True tuned_model = model.tune() assert ( tuned_model.generate_serving_properties() @@ -421,6 +428,7 @@ def test_tune_for_djl_local_container_oom_ex( builder._prepare_for_mode.side_effect = None model = builder.build() + builder.serve_settings.telemetry_opt_out = True tuned_model = model.tune() assert ( tuned_model.generate_serving_properties() @@ -473,6 +481,7 @@ def test_tune_for_djl_local_container_invoke_ex( builder._prepare_for_mode.side_effect = None model = builder.build() + builder.serve_settings.telemetry_opt_out = True tuned_model = model.tune() assert ( tuned_model.generate_serving_properties() diff --git a/tests/unit/sagemaker/serve/model_server/constants.py b/tests/unit/sagemaker/serve/model_server/constants.py new file mode 100644 index 0000000000..41a0a832cb --- /dev/null +++ b/tests/unit/sagemaker/serve/model_server/constants.py @@ -0,0 +1,33 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +MOCK_MODEL_PATH = "/path/to/mock/model/dir" +MOCK_CODE_DIR = "/path/to/mock/model/dir/code" +MOCK_JUMPSTART_ID = "mock_llm_js_id" +MOCK_TMP_DIR = "tmp123456" +MOCK_COMPRESSED_MODEL_DATA_STR = ( + "s3://jumpstart-cache/to/infer-prepack-huggingface-llm-falcon-7b-bf16.tar.gz" +) +MOCK_UNCOMPRESSED_MODEL_DATA_STR = "s3://jumpstart-cache/to/artifacts/inference-prepack/v1.0.1/" +MOCK_UNCOMPRESSED_MODEL_DATA_STR_FOR_DICT = ( + "s3://jumpstart-cache/to/artifacts/inference-prepack/v1.0.1/dict/" +) +MOCK_UNCOMPRESSED_MODEL_DATA_DICT = { + "S3DataSource": { + "S3Uri": MOCK_UNCOMPRESSED_MODEL_DATA_STR_FOR_DICT, + "S3DataType": "S3Prefix", + "CompressionType": "None", + } +} +MOCK_INVALID_MODEL_DATA_DICT = {} diff --git a/tests/unit/sagemaker/serve/model_server/djl_serving/test_djl_prepare.py b/tests/unit/sagemaker/serve/model_server/djl_serving/test_djl_prepare.py new file mode 100644 index 0000000000..40d3edb251 --- /dev/null +++ b/tests/unit/sagemaker/serve/model_server/djl_serving/test_djl_prepare.py @@ -0,0 +1,275 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from unittest import TestCase +from unittest.mock import Mock, PropertyMock, patch, mock_open, call + +from sagemaker.serve.model_server.djl_serving.prepare import ( + _copy_jumpstart_artifacts, + _create_dir_structure, + _move_to_code_dir, + _extract_js_resource, +) +from tests.unit.sagemaker.serve.model_server.constants import ( + MOCK_JUMPSTART_ID, + MOCK_TMP_DIR, + MOCK_COMPRESSED_MODEL_DATA_STR, + MOCK_UNCOMPRESSED_MODEL_DATA_STR, + MOCK_UNCOMPRESSED_MODEL_DATA_STR_FOR_DICT, + MOCK_UNCOMPRESSED_MODEL_DATA_DICT, + MOCK_INVALID_MODEL_DATA_DICT, +) + +MOCK_DJL_JUMPSTART_GLOBED_RESOURCES = ["./inference.py", "./serving.properties", "./config.json"] + + +class DjlPrepareTests(TestCase): + @patch("sagemaker.serve.model_server.djl_serving.prepare._check_disk_space") + @patch("sagemaker.serve.model_server.djl_serving.prepare._check_docker_disk_usage") + @patch("sagemaker.serve.model_server.djl_serving.prepare.Path") + def test_create_dir_structure_from_new(self, mock_path, mock_disk_usage, mock_disk_space): + mock_model_path = Mock() + mock_model_path.exists.return_value = False + mock_code_dir = Mock() + mock_model_path.joinpath.return_value = mock_code_dir + mock_path.return_value = mock_model_path + + ret_model_path, ret_code_dir = _create_dir_structure(mock_model_path) + + mock_model_path.mkdir.assert_called_once_with(parents=True) + mock_model_path.joinpath.assert_called_once_with("code") + mock_code_dir.mkdir.assert_called_once_with(exist_ok=True, parents=True) + mock_disk_space.assert_called_once_with(mock_model_path) + mock_disk_usage.assert_called_once() + + self.assertEquals(ret_model_path, mock_model_path) + self.assertEquals(ret_code_dir, mock_code_dir) + + @patch("sagemaker.serve.model_server.djl_serving.prepare.Path") + def test_create_dir_structure_invalid_path(self, mock_path): + mock_model_path = Mock() + mock_model_path.exists.return_value = True + mock_model_path.is_dir.return_value = False + mock_path.return_value = mock_model_path + + with self.assertRaises(ValueError) as context: + _create_dir_structure(mock_model_path) + + self.assertEquals("model_dir is not a valid directory", str(context.exception)) + + @patch("sagemaker.serve.model_server.djl_serving.prepare.S3Downloader") + @patch("sagemaker.serve.model_server.djl_serving.prepare._tmpdir") + @patch( + "sagemaker.serve.model_server.djl_serving.prepare._read_existing_serving_properties", + return_value={}, + ) + @patch("sagemaker.serve.model_server.djl_serving.prepare._move_to_code_dir") + @patch("builtins.open", new_callable=mock_open, read_data="data") + @patch("json.load", return_value={}) + def test_prepare_djl_js_resources_for_jumpstart_uncompressed_str( + self, + mock_load, + mock_open, + mock_move_to_code_dir, + mock_existing_props, + mock_tmpdir, + mock_s3_downloader, + ): + mock_code_dir = Mock() + mock_config_json_file = Mock() + mock_config_json_file.is_file.return_value = True + mock_code_dir.joinpath.return_value = mock_config_json_file + + mock_s3_downloader_obj = Mock() + mock_s3_downloader.return_value = mock_s3_downloader_obj + + mock_tmpdir_obj = Mock() + mock_js_dir = Mock() + mock_js_dir.return_value = MOCK_TMP_DIR + type(mock_tmpdir_obj).__enter__ = PropertyMock(return_value=mock_js_dir) + type(mock_tmpdir_obj).__exit__ = PropertyMock(return_value=Mock()) + mock_tmpdir.return_value = mock_tmpdir_obj + + existing_properties, hf_model_config, success = _copy_jumpstart_artifacts( + MOCK_UNCOMPRESSED_MODEL_DATA_STR, MOCK_JUMPSTART_ID, mock_code_dir + ) + + mock_s3_downloader_obj.download.assert_called_once_with( + MOCK_UNCOMPRESSED_MODEL_DATA_STR, MOCK_TMP_DIR + ) + mock_move_to_code_dir.assert_called_once_with(MOCK_TMP_DIR, mock_code_dir) + mock_code_dir.joinpath.assert_called_once_with("config.json") + self.assertEqual(existing_properties, {}) + self.assertEqual(hf_model_config, {}) + self.assertEqual(success, True) + + @patch("sagemaker.serve.model_server.djl_serving.prepare.S3Downloader") + @patch("sagemaker.serve.model_server.djl_serving.prepare._tmpdir") + @patch( + "sagemaker.serve.model_server.djl_serving.prepare._read_existing_serving_properties", + return_value={}, + ) + @patch("sagemaker.serve.model_server.djl_serving.prepare._move_to_code_dir") + @patch("builtins.open", new_callable=mock_open, read_data="data") + @patch("json.load", return_value={}) + def test_prepare_djl_js_resources_for_jumpstart_uncompressed_dict( + self, + mock_load, + mock_open, + mock_move_to_code_dir, + mock_existing_props, + mock_tmpdir, + mock_s3_downloader, + ): + mock_code_dir = Mock() + mock_config_json_file = Mock() + mock_config_json_file.is_file.return_value = True + mock_code_dir.joinpath.return_value = mock_config_json_file + + mock_s3_downloader_obj = Mock() + mock_s3_downloader.return_value = mock_s3_downloader_obj + + mock_tmpdir_obj = Mock() + mock_js_dir = Mock() + mock_js_dir.return_value = MOCK_TMP_DIR + type(mock_tmpdir_obj).__enter__ = PropertyMock(return_value=mock_js_dir) + type(mock_tmpdir_obj).__exit__ = PropertyMock(return_value=Mock()) + mock_tmpdir.return_value = mock_tmpdir_obj + + existing_properties, hf_model_config, success = _copy_jumpstart_artifacts( + MOCK_UNCOMPRESSED_MODEL_DATA_DICT, MOCK_JUMPSTART_ID, mock_code_dir + ) + + mock_s3_downloader_obj.download.assert_called_once_with( + MOCK_UNCOMPRESSED_MODEL_DATA_STR_FOR_DICT, MOCK_TMP_DIR + ) + mock_move_to_code_dir.assert_called_once_with(MOCK_TMP_DIR, mock_code_dir) + mock_code_dir.joinpath.assert_called_once_with("config.json") + self.assertEqual(existing_properties, {}) + self.assertEqual(hf_model_config, {}) + self.assertEqual(success, True) + + @patch("sagemaker.serve.model_server.djl_serving.prepare._tmpdir") + @patch("sagemaker.serve.model_server.djl_serving.prepare._move_to_code_dir") + def test_prepare_djl_js_resources_for_jumpstart_invalid_model_data( + self, mock_move_to_code_dir, mock_tmpdir + ): + mock_code_dir = Mock() + mock_tmpdir_obj = Mock() + type(mock_tmpdir_obj).__enter__ = PropertyMock(return_value=Mock()) + type(mock_tmpdir_obj).__exit__ = PropertyMock(return_value=Mock()) + mock_tmpdir.return_value = mock_tmpdir_obj + + with self.assertRaises(ValueError) as context: + _copy_jumpstart_artifacts( + MOCK_INVALID_MODEL_DATA_DICT, MOCK_JUMPSTART_ID, mock_code_dir + ) + + assert not mock_move_to_code_dir.called + self.assertTrue( + "JumpStart model data compression format is unsupported" in str(context.exception) + ) + + @patch("sagemaker.serve.model_server.djl_serving.prepare.S3Downloader") + @patch("sagemaker.serve.model_server.djl_serving.prepare._extract_js_resource") + @patch("sagemaker.serve.model_server.djl_serving.prepare._tmpdir") + @patch( + "sagemaker.serve.model_server.djl_serving.prepare._read_existing_serving_properties", + return_value={}, + ) + @patch("sagemaker.serve.model_server.djl_serving.prepare._move_to_code_dir") + @patch("builtins.open", new_callable=mock_open, read_data="data") + @patch("json.load", return_value={}) + def test_prepare_djl_js_resources_for_jumpstart_compressed_str( + self, + mock_load, + mock_open, + mock_move_to_code_dir, + mock_existing_props, + mock_tmpdir, + mock_extract_js_resource, + mock_s3_downloader, + ): + mock_code_dir = Mock() + mock_config_json_file = Mock() + mock_config_json_file.is_file.return_value = True + mock_code_dir.joinpath.return_value = mock_config_json_file + + mock_s3_downloader_obj = Mock() + mock_s3_downloader.return_value = mock_s3_downloader_obj + + mock_tmpdir_obj = Mock() + mock_js_dir = Mock() + mock_js_dir.return_value = MOCK_TMP_DIR + type(mock_tmpdir_obj).__enter__ = PropertyMock(return_value=mock_js_dir) + type(mock_tmpdir_obj).__exit__ = PropertyMock(return_value=Mock()) + mock_tmpdir.return_value = mock_tmpdir_obj + + existing_properties, hf_model_config, success = _copy_jumpstart_artifacts( + MOCK_COMPRESSED_MODEL_DATA_STR, MOCK_JUMPSTART_ID, mock_code_dir + ) + + mock_s3_downloader_obj.download.assert_called_once_with( + MOCK_COMPRESSED_MODEL_DATA_STR, MOCK_TMP_DIR + ) + mock_extract_js_resource.assert_called_with(MOCK_TMP_DIR, MOCK_JUMPSTART_ID) + mock_move_to_code_dir.assert_called_once_with(MOCK_TMP_DIR, mock_code_dir) + mock_code_dir.joinpath.assert_called_once_with("config.json") + self.assertEqual(existing_properties, {}) + self.assertEqual(hf_model_config, {}) + self.assertEqual(success, True) + + @patch("sagemaker.serve.model_server.djl_serving.prepare.Path") + @patch("sagemaker.serve.model_server.djl_serving.prepare.shutil") + def test_move_to_code_dir_success(self, mock_shutil, mock_path): + mock_path_obj = Mock() + mock_js_model_resources = Mock() + mock_js_model_resources.glob.return_value = MOCK_DJL_JUMPSTART_GLOBED_RESOURCES + mock_path_obj.joinpath.return_value = mock_js_model_resources + mock_path.return_value = mock_path_obj + + mock_js_model_dir = "" + mock_code_dir = Mock() + _move_to_code_dir(mock_js_model_dir, mock_code_dir) + + mock_path_obj.joinpath.assert_called_once_with("model") + + expected_moves = [ + call("./inference.py", mock_code_dir), + call("./serving.properties", mock_code_dir), + call("./config.json", mock_code_dir), + ] + mock_shutil.move.assert_has_calls(expected_moves) + + @patch("sagemaker.serve.model_server.djl_serving.prepare.Path") + @patch("sagemaker.serve.model_server.djl_serving.prepare.tarfile") + def test_extract_js_resources_success(self, mock_tarfile, mock_path): + mock_path_obj = Mock() + mock_path_obj.joinpath.return_value = Mock() + mock_path.return_value = mock_path_obj + + mock_tar_obj = Mock() + mock_enter = Mock() + mock_resource_obj = Mock() + mock_enter.return_value = mock_resource_obj + type(mock_tar_obj).__enter__ = PropertyMock(return_value=mock_enter) + type(mock_tar_obj).__exit__ = PropertyMock(return_value=Mock()) + mock_tarfile.open.return_value = mock_tar_obj + + js_model_dir = "" + _extract_js_resource(js_model_dir, MOCK_JUMPSTART_ID) + + mock_path.assert_called_once_with(js_model_dir) + mock_path_obj.joinpath.assert_called_once_with(f"infer-prepack-{MOCK_JUMPSTART_ID}.tar.gz") + mock_resource_obj.extractall.assert_called_once_with(path=js_model_dir) diff --git a/tests/unit/sagemaker/serve/model_server/tgi/test_tgi_prepare.py b/tests/unit/sagemaker/serve/model_server/tgi/test_tgi_prepare.py new file mode 100644 index 0000000000..c055be1f7d --- /dev/null +++ b/tests/unit/sagemaker/serve/model_server/tgi/test_tgi_prepare.py @@ -0,0 +1,159 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from unittest import TestCase +from unittest.mock import Mock, PropertyMock, patch + +from sagemaker.serve.model_server.tgi.prepare import ( + _create_dir_structure, + _copy_jumpstart_artifacts, + _extract_js_resource, +) +from tests.unit.sagemaker.serve.model_server.constants import ( + MOCK_JUMPSTART_ID, + MOCK_TMP_DIR, + MOCK_COMPRESSED_MODEL_DATA_STR, + MOCK_UNCOMPRESSED_MODEL_DATA_STR, + MOCK_UNCOMPRESSED_MODEL_DATA_STR_FOR_DICT, + MOCK_UNCOMPRESSED_MODEL_DATA_DICT, + MOCK_INVALID_MODEL_DATA_DICT, +) + + +class TgiPrepareTests(TestCase): + @patch("sagemaker.serve.model_server.tgi.prepare._check_disk_space") + @patch("sagemaker.serve.model_server.tgi.prepare._check_docker_disk_usage") + @patch("sagemaker.serve.model_server.tgi.prepare.Path") + def test_create_dir_structure_from_new(self, mock_path, mock_disk_usage, mock_disk_space): + mock_model_path = Mock() + mock_model_path.exists.return_value = False + mock_code_dir = Mock() + mock_model_path.joinpath.return_value = mock_code_dir + mock_path.return_value = mock_model_path + + ret_model_path, ret_code_dir = _create_dir_structure(mock_model_path) + + mock_model_path.mkdir.assert_called_once_with(parents=True) + mock_model_path.joinpath.assert_called_once_with("code") + mock_code_dir.mkdir.assert_called_once_with(exist_ok=True, parents=True) + mock_disk_space.assert_called_once_with(mock_model_path) + mock_disk_usage.assert_called_once() + + self.assertEquals(ret_model_path, mock_model_path) + self.assertEquals(ret_code_dir, mock_code_dir) + + @patch("sagemaker.serve.model_server.tgi.prepare.Path") + def test_create_dir_structure_invalid_path(self, mock_path): + mock_model_path = Mock() + mock_model_path.exists.return_value = True + mock_model_path.is_dir.return_value = False + mock_path.return_value = mock_model_path + + with self.assertRaises(ValueError) as context: + _create_dir_structure(mock_model_path) + + self.assertEquals("model_dir is not a valid directory", str(context.exception)) + + @patch("sagemaker.serve.model_server.tgi.prepare.S3Downloader") + def test_prepare_tgi_js_resources_for_jumpstart_uncompressed_str(self, mock_s3_downloader): + mock_code_dir = Mock() + mock_s3_downloader_obj = Mock() + mock_s3_downloader.return_value = mock_s3_downloader_obj + + _copy_jumpstart_artifacts( + MOCK_UNCOMPRESSED_MODEL_DATA_STR, MOCK_JUMPSTART_ID, mock_code_dir + ) + + mock_s3_downloader_obj.download.assert_called_once_with( + MOCK_UNCOMPRESSED_MODEL_DATA_STR, mock_code_dir + ) + + @patch("sagemaker.serve.model_server.tgi.prepare.S3Downloader") + def test_prepare_tgi_js_resources_for_jumpstart_invalid_model_data(self, mock_s3_downloader): + mock_code_dir = Mock() + mock_s3_downloader_obj = Mock() + mock_s3_downloader.return_value = mock_s3_downloader_obj + + _copy_jumpstart_artifacts( + MOCK_UNCOMPRESSED_MODEL_DATA_DICT, MOCK_JUMPSTART_ID, mock_code_dir + ) + + mock_s3_downloader_obj.download.assert_called_once_with( + MOCK_UNCOMPRESSED_MODEL_DATA_STR_FOR_DICT, mock_code_dir + ) + + def test_prepare_tgi_js_resources_for_jumpstart_invalid_format(self): + mock_code_dir = Mock() + + with self.assertRaises(ValueError) as context: + _copy_jumpstart_artifacts( + MOCK_INVALID_MODEL_DATA_DICT, MOCK_JUMPSTART_ID, mock_code_dir + ) + + self.assertTrue( + "JumpStart model data compression format is unsupported" in str(context.exception) + ) + + @patch("sagemaker.serve.model_server.tgi.prepare.S3Downloader") + @patch("sagemaker.serve.model_server.tgi.prepare._tmpdir") + @patch("sagemaker.serve.model_server.tgi.prepare._extract_js_resource") + def test_prepare_tgi_js_resources_for_jumpstart_compressed_str( + self, + mock_extract_js_resource, + mock_tmpdir, + mock_s3_downloader, + ): + mock_code_dir = Mock() + + mock_s3_downloader_obj = Mock() + mock_s3_downloader.return_value = mock_s3_downloader_obj + + mock_tmpdir_obj = Mock() + mock_js_dir = Mock() + mock_js_dir.return_value = MOCK_TMP_DIR + type(mock_tmpdir_obj).__enter__ = PropertyMock(return_value=mock_js_dir) + type(mock_tmpdir_obj).__exit__ = PropertyMock(return_value=Mock()) + mock_tmpdir.return_value = mock_tmpdir_obj + + _copy_jumpstart_artifacts(MOCK_COMPRESSED_MODEL_DATA_STR, MOCK_JUMPSTART_ID, mock_code_dir) + + mock_s3_downloader_obj.download.assert_called_once_with( + MOCK_COMPRESSED_MODEL_DATA_STR, MOCK_TMP_DIR + ) + mock_extract_js_resource.assert_called_once_with( + MOCK_TMP_DIR, mock_code_dir, MOCK_JUMPSTART_ID + ) + + @patch("sagemaker.serve.model_server.tgi.prepare.Path") + @patch("sagemaker.serve.model_server.tgi.prepare.tarfile") + def test_extract_js_resources_success(self, mock_tarfile, mock_path): + mock_path_obj = Mock() + mock_path_obj.joinpath.return_value = Mock() + mock_path.return_value = mock_path_obj + + mock_tar_obj = Mock() + mock_enter = Mock() + mock_resource_obj = Mock() + mock_enter.return_value = mock_resource_obj + type(mock_tar_obj).__enter__ = PropertyMock(return_value=mock_enter) + type(mock_tar_obj).__exit__ = PropertyMock(return_value=Mock()) + mock_tarfile.open.return_value = mock_tar_obj + + js_model_dir = "" + code_dir = Mock() + _extract_js_resource(js_model_dir, code_dir, MOCK_JUMPSTART_ID) + + mock_path.assert_called_once_with(js_model_dir) + mock_path_obj.joinpath.assert_called_once_with(f"infer-prepack-{MOCK_JUMPSTART_ID}.tar.gz") + mock_resource_obj.extractall.assert_called_once_with(path=code_dir)