|
18 | 18 |
|
19 | 19 | from sagemaker.serve.builder.model_builder import ModelBuilder
|
20 | 20 | from sagemaker.serve.mode.function_pointers import Mode
|
| 21 | +from sagemaker.serve.utils import task |
| 22 | +from sagemaker.serve.utils.exceptions import TaskNotFoundException |
21 | 23 | from sagemaker.serve.utils.types import ModelServer
|
22 | 24 | from tests.unit.sagemaker.serve.constants import MOCK_IMAGE_CONFIG, MOCK_VPC_CONFIG
|
23 | 25 |
|
@@ -985,3 +987,93 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co
|
985 | 987 | build_result.deploy(mode=Mode.LOCAL_CONTAINER)
|
986 | 988 |
|
987 | 989 | self.assertEqual(builder.mode, Mode.LOCAL_CONTAINER)
|
| 990 | + |
| 991 | + @patch("sagemaker.serve.builder.tgi_builder.HuggingFaceModel") |
| 992 | + @patch("sagemaker.image_uris.retrieve") |
| 993 | + @patch("sagemaker.djl_inference.model.urllib") |
| 994 | + @patch("sagemaker.djl_inference.model.json") |
| 995 | + @patch("sagemaker.huggingface.llm_utils.urllib") |
| 996 | + @patch("sagemaker.huggingface.llm_utils.json") |
| 997 | + @patch("sagemaker.model_uris.retrieve") |
| 998 | + @patch("sagemaker.serve.builder.model_builder._ServeSettings") |
| 999 | + def test_build_happy_path_when_schema_builder_not_present( |
| 1000 | + self, |
| 1001 | + mock_serveSettings, |
| 1002 | + mock_model_uris_retrieve, |
| 1003 | + mock_llm_utils_json, |
| 1004 | + mock_llm_utils_urllib, |
| 1005 | + mock_model_json, |
| 1006 | + mock_model_urllib, |
| 1007 | + mock_image_uris_retrieve, |
| 1008 | + mock_hf_model, |
| 1009 | + ): |
| 1010 | + # Setup mocks |
| 1011 | + |
| 1012 | + mock_setting_object = mock_serveSettings.return_value |
| 1013 | + mock_setting_object.role_arn = mock_role_arn |
| 1014 | + mock_setting_object.s3_model_data_url = mock_s3_model_data_url |
| 1015 | + |
| 1016 | + # HF Pipeline Tag |
| 1017 | + mock_model_uris_retrieve.side_effect = KeyError |
| 1018 | + mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-generation"} |
| 1019 | + mock_llm_utils_urllib.request.Request.side_effect = Mock() |
| 1020 | + |
| 1021 | + # HF Model config |
| 1022 | + mock_model_json.load.return_value = {"some": "config"} |
| 1023 | + mock_model_urllib.request.Request.side_effect = Mock() |
| 1024 | + |
| 1025 | + mock_image_uris_retrieve.return_value = "https://some-image-uri" |
| 1026 | + |
| 1027 | + model_builder = ModelBuilder(model="meta-llama/Llama-2-7b-hf") |
| 1028 | + model_builder.build(sagemaker_session=mock_session) |
| 1029 | + |
| 1030 | + self.assertIsNotNone(model_builder.schema_builder) |
| 1031 | + sample_inputs, sample_outputs = task.retrieve_local_schemas("text-generation") |
| 1032 | + self.assertEqual( |
| 1033 | + sample_inputs["inputs"], model_builder.schema_builder.sample_input["inputs"] |
| 1034 | + ) |
| 1035 | + self.assertEqual(sample_outputs, model_builder.schema_builder.sample_output) |
| 1036 | + |
| 1037 | + @patch("sagemaker.serve.builder.tgi_builder.HuggingFaceModel") |
| 1038 | + @patch("sagemaker.image_uris.retrieve") |
| 1039 | + @patch("sagemaker.djl_inference.model.urllib") |
| 1040 | + @patch("sagemaker.djl_inference.model.json") |
| 1041 | + @patch("sagemaker.huggingface.llm_utils.urllib") |
| 1042 | + @patch("sagemaker.huggingface.llm_utils.json") |
| 1043 | + @patch("sagemaker.model_uris.retrieve") |
| 1044 | + @patch("sagemaker.serve.builder.model_builder._ServeSettings") |
| 1045 | + def test_build_negative_path_when_schema_builder_not_present( |
| 1046 | + self, |
| 1047 | + mock_serveSettings, |
| 1048 | + mock_model_uris_retrieve, |
| 1049 | + mock_llm_utils_json, |
| 1050 | + mock_llm_utils_urllib, |
| 1051 | + mock_model_json, |
| 1052 | + mock_model_urllib, |
| 1053 | + mock_image_uris_retrieve, |
| 1054 | + mock_hf_model, |
| 1055 | + ): |
| 1056 | + # Setup mocks |
| 1057 | + |
| 1058 | + mock_setting_object = mock_serveSettings.return_value |
| 1059 | + mock_setting_object.role_arn = mock_role_arn |
| 1060 | + mock_setting_object.s3_model_data_url = mock_s3_model_data_url |
| 1061 | + |
| 1062 | + # HF Pipeline Tag |
| 1063 | + mock_model_uris_retrieve.side_effect = KeyError |
| 1064 | + mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-to-image"} |
| 1065 | + mock_llm_utils_urllib.request.Request.side_effect = Mock() |
| 1066 | + |
| 1067 | + # HF Model config |
| 1068 | + mock_model_json.load.return_value = {"some": "config"} |
| 1069 | + mock_model_urllib.request.Request.side_effect = Mock() |
| 1070 | + |
| 1071 | + mock_image_uris_retrieve.return_value = "https://some-image-uri" |
| 1072 | + |
| 1073 | + model_builder = ModelBuilder(model="CompVis/stable-diffusion-v1-4") |
| 1074 | + |
| 1075 | + self.assertRaisesRegexp( |
| 1076 | + TaskNotFoundException, |
| 1077 | + "Error Message: Schema builder for text-to-image could not be found.", |
| 1078 | + lambda: model_builder.build(sagemaker_session=mock_session), |
| 1079 | + ) |
0 commit comments