def _schema_builder_init(self, model_task: str):
    """Initialize ``self.schema_builder`` from the locally stored task schemas.

    Looks up the sample input/output pair registered for ``model_task`` via
    ``task.retrieve_local_schemas`` and wraps the pair in a ``SchemaBuilder``,
    which is then stored on the instance.

    Args:
        model_task (str): Required, the task name

    Raises:
        TaskNotFoundException: If the I/O schema for the given task is not found.
    """
    try:
        sample_inputs, sample_outputs = task.retrieve_local_schemas(model_task)
        self.schema_builder = SchemaBuilder(sample_inputs, sample_outputs)
    except ValueError as err:
        # Chain explicitly so the traceback keeps the underlying ValueError
        # (and its message) as the documented cause of this failure.
        raise TaskNotFoundException(
            f"Schema builder for {model_task} could not be found."
        ) from err
def retrieve_local_schemas(task: str) -> Tuple[Any, Any]:
    """Retrieves task sample inputs and outputs locally.

    The schemas are read from ``schema/task.json``, resolved relative to the
    parent of the directory that contains this module.

    Args:
        task (str): Required, the task name

    Returns:
        Tuple[Any, Any]: A tuple that contains the sample input,
        at index 0, and output schema, at index 1.

    Raises:
        ValueError: If no tasks config found or the task does not exist in the local config.
    """
    config_dir = os.path.dirname(os.path.dirname(__file__))
    task_io_config_path = os.path.join(config_dir, "schema", "task.json")

    # Only opening the file can raise FileNotFoundError, so keep the guarded
    # region to exactly that. JSON is read as UTF-8 regardless of the locale.
    try:
        with open(task_io_config_path, encoding="utf-8") as f:
            task_io_config = json.load(f)
    except FileNotFoundError as err:
        raise ValueError("Could not find tasks config file.") from err

    task_io_schemas = task_io_config.get(task)
    if task_io_schemas is None:
        raise ValueError(f"Could not find {task} I/O schema.")

    return (
        task_io_schemas["sample_inputs"]["properties"],
        task_io_schemas["sample_outputs"]["properties"],
    )
+from __future__ import absolute_import + +from sagemaker.serve.builder.model_builder import ModelBuilder +from sagemaker.serve.utils import task + +import pytest + +from sagemaker.serve.utils.exceptions import TaskNotFoundException +from tests.integ.sagemaker.serve.constants import ( + PYTHON_VERSION_IS_NOT_310, + SERVE_SAGEMAKER_ENDPOINT_TIMEOUT, +) + +from tests.integ.timeout import timeout +from tests.integ.utils import cleanup_model_resources + +import logging + +logger = logging.getLogger(__name__) + + +def test_model_builder_happy_path_with_only_model_id_fill_mask(sagemaker_session): + model_builder = ModelBuilder(model="bert-base-uncased") + + model = model_builder.build(sagemaker_session=sagemaker_session) + + assert model is not None + assert model_builder.schema_builder is not None + + inputs, outputs = task.retrieve_local_schemas("fill-mask") + assert model_builder.schema_builder.sample_input == inputs + assert model_builder.schema_builder.sample_output == outputs + + +@pytest.mark.skipif( + PYTHON_VERSION_IS_NOT_310, + reason="Testing Schema Builder Simplification feature", +) +def test_model_builder_happy_path_with_only_model_id_question_answering( + sagemaker_session, gpu_instance_type +): + model_builder = ModelBuilder(model="bert-large-uncased-whole-word-masking-finetuned-squad") + + model = model_builder.build(sagemaker_session=sagemaker_session) + + assert model is not None + assert model_builder.schema_builder is not None + + inputs, outputs = task.retrieve_local_schemas("question-answering") + assert model_builder.schema_builder.sample_input == inputs + assert model_builder.schema_builder.sample_output == outputs + + with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT): + caught_ex = None + try: + iam_client = sagemaker_session.boto_session.client("iam") + role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"] + + logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...") + predictor = model.deploy( + 
role=role_arn, instance_count=1, instance_type=gpu_instance_type + ) + + predicted_outputs = predictor.predict(inputs) + assert predicted_outputs is not None + + except Exception as e: + caught_ex = e + finally: + cleanup_model_resources( + sagemaker_session=model_builder.sagemaker_session, + model_name=model.name, + endpoint_name=model.endpoint_name, + ) + if caught_ex: + logger.exception(caught_ex) + assert ( + False + ), f"{caught_ex} was thrown when running transformers sagemaker endpoint test" + + +def test_model_builder_negative_path(sagemaker_session): + model_builder = ModelBuilder(model="CompVis/stable-diffusion-v1-4") + + with pytest.raises( + TaskNotFoundException, + match="Error Message: Schema builder for text-to-image could not be found.", + ): + model_builder.build(sagemaker_session=sagemaker_session) diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index 898536c03f..becf63ab41 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -18,6 +18,8 @@ from sagemaker.serve.builder.model_builder import ModelBuilder from sagemaker.serve.mode.function_pointers import Mode +from sagemaker.serve.utils import task +from sagemaker.serve.utils.exceptions import TaskNotFoundException from sagemaker.serve.utils.types import ModelServer from tests.unit.sagemaker.serve.constants import MOCK_IMAGE_CONFIG, MOCK_VPC_CONFIG @@ -985,3 +987,93 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co build_result.deploy(mode=Mode.LOCAL_CONTAINER) self.assertEqual(builder.mode, Mode.LOCAL_CONTAINER) + + @patch("sagemaker.serve.builder.tgi_builder.HuggingFaceModel") + @patch("sagemaker.image_uris.retrieve") + @patch("sagemaker.djl_inference.model.urllib") + @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.huggingface.llm_utils.urllib") + 
@patch("sagemaker.serve.builder.tgi_builder.HuggingFaceModel")
@patch("sagemaker.image_uris.retrieve")
@patch("sagemaker.djl_inference.model.urllib")
@patch("sagemaker.djl_inference.model.json")
@patch("sagemaker.huggingface.llm_utils.urllib")
@patch("sagemaker.huggingface.llm_utils.json")
@patch("sagemaker.model_uris.retrieve")
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
def test_build_negative_path_when_schema_builder_not_present(
    self,
    mock_serveSettings,
    mock_model_uris_retrieve,
    mock_llm_utils_json,
    mock_llm_utils_urllib,
    mock_model_json,
    mock_model_urllib,
    mock_image_uris_retrieve,
    mock_hf_model,
):
    """Expect build() to raise TaskNotFoundException for the "text-to-image" task.

    No schema builder is passed to ModelBuilder, and the mocked Hugging Face
    metadata reports the pipeline tag "text-to-image".
    """
    # Setup mocks

    mock_setting_object = mock_serveSettings.return_value
    mock_setting_object.role_arn = mock_role_arn
    mock_setting_object.s3_model_data_url = mock_s3_model_data_url

    # HF Pipeline Tag
    mock_model_uris_retrieve.side_effect = KeyError
    mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-to-image"}
    mock_llm_utils_urllib.request.Request.side_effect = Mock()

    # HF Model config
    mock_model_json.load.return_value = {"some": "config"}
    mock_model_urllib.request.Request.side_effect = Mock()

    mock_image_uris_retrieve.return_value = "https://some-image-uri"

    model_builder = ModelBuilder(model="CompVis/stable-diffusion-v1-4")

    # assertRaisesRegexp was a deprecated alias that Python 3.12 removed;
    # assertRaisesRegex is the supported name.
    with self.assertRaisesRegex(
        TaskNotFoundException,
        "Error Message: Schema builder for text-to-image could not be found.",
    ):
        model_builder.build(sagemaker_session=mock_session)
+from __future__ import absolute_import + +from unittest.mock import patch + +import pytest + +from sagemaker.serve.utils import task + +EXPECTED_INPUTS = {"inputs": "Paris is the of France.", "parameters": {}} +EXPECTED_OUTPUTS = [{"sequence": "Paris is the capital of France.", "score": 0.7}] +HF_INVALID_TASK = "not-present-task" + + +def test_retrieve_local_schemas_success(): + inputs, outputs = task.retrieve_local_schemas("fill-mask") + + assert inputs == EXPECTED_INPUTS + assert outputs == EXPECTED_OUTPUTS + + +def test_retrieve_local_schemas_text_generation_success(): + inputs, outputs = task.retrieve_local_schemas("text-generation") + + assert inputs is not None + assert outputs is not None + + +def test_retrieve_local_schemas_throws(): + with pytest.raises(ValueError, match=f"Could not find {HF_INVALID_TASK} I/O schema."): + task.retrieve_local_schemas(HF_INVALID_TASK) + + +@patch("builtins.open") +def test_retrieve_local_schemas_file_not_found(mock_open): + mock_open.side_effect = FileNotFoundError + with pytest.raises(ValueError, match="Could not find tasks config file."): + task.retrieve_local_schemas(HF_INVALID_TASK)