From fd3e86b19f091dc1ce4d7906644107e08768c6a4 Mon Sep 17 00:00:00 2001 From: Samrudhi Sharma Date: Tue, 20 Feb 2024 15:12:24 -0800 Subject: [PATCH 01/11] feat: Add Optional task to Model --- src/sagemaker/model.py | 6 +- src/sagemaker/serve/builder/model_builder.py | 1 + .../serve/test_serve_transformers.py | 56 +++++++++++++++++++ tests/unit/sagemaker/model/test_deploy.py | 55 ++++++++++++++++++ 4 files changed, 117 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index ff340b58e9..36b49f1ac1 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -156,6 +156,7 @@ def __init__( dependencies: Optional[List[str]] = None, git_config: Optional[Dict[str, str]] = None, resources: Optional[ResourceRequirements] = None, + task: Optional[Union[str, PipelineVariable]] = None, ): """Initialize an SageMaker ``Model``. @@ -319,7 +320,9 @@ def __init__( for a model to be deployed to an endpoint. Only EndpointType.INFERENCE_COMPONENT_BASED supports this feature. (Default: None). - + task (str or PipelineVariable): Task values used to override the HuggingFace task + Examples are: "audio-classification", "depth-estimation", + "feature-extraction" etc. (default: None). """ self.model_data = model_data self.image_uri = image_uri @@ -396,6 +399,7 @@ def __init__( self.content_types = None self.response_types = None self.accept_eula = None + self.task = task @runnable_by_pipeline def register( diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 0ade8096f6..d5b28e5803 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -203,6 +203,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): "help": ( 'Model object with "predict" method to perform inference ' "or HuggingFace/JumpStart Model ID" + "or HuggingFace Task to override" ) }, ) diff --git a/tests/integ/sagemaker/serve/test_serve_transformers.py b/tests/integ/sagemaker/serve/test_serve_transformers.py index 735f60d0f2..958c469159 100644 --- a/tests/integ/sagemaker/serve/test_serve_transformers.py +++ b/tests/integ/sagemaker/serve/test_serve_transformers.py @@ -13,6 +13,7 @@ from __future__ import absolute_import import pytest +from sagemaker.model import Model from sagemaker.serve.builder.schema_builder import SchemaBuilder from sagemaker.serve.builder.model_builder import ModelBuilder, Mode @@ -28,6 +29,11 @@ logger = logging.getLogger(__name__) +MODEL_DATA = "s3://bucket/model.tar.gz" +MODEL_IMAGE = "mi" +ROLE = "some-role" +HF_TASK = "fill-mask" + sample_input = { "inputs": "The man worked as a [MASK].", } @@ -80,6 +86,17 @@ def model_builder_model_schema_builder(): ) +@pytest.fixture +def model_builder_model_with_task_builder(): + model = Model( + MODEL_IMAGE, MODEL_DATA, task=HF_TASK, name="bert-base-uncased", role=ROLE + ) + return ModelBuilder( + model_path=HF_DIR, + model=model, + ) + + @pytest.fixture def model_builder(request): return request.getfixturevalue(request.param) @@ -122,3 +139,42 @@ def test_pytorch_transformers_sagemaker_endpoint( assert ( False ), f"{caught_ex} was thrown when running pytorch transformers sagemaker endpoint test" + + +@pytest.mark.skipif( + PYTHON_VERSION_IS_NOT_310, + reason="Testing Optional task", +) +@pytest.mark.parametrize("model_builder", ["model_builder_model_with_task_builder"], indirect=True) +def test_happy_path_with_task_sagemaker_endpoint( + sagemaker_session, model_builder, gpu_instance_type, input +): + logger.info("Running in SAGEMAKER_ENDPOINT mode...") + caught_ex = None + + iam_client = sagemaker_session.boto_session.client("iam") + role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"] + + model = model_builder.build( + mode=Mode.SAGEMAKER_ENDPOINT, role_arn=role_arn, sagemaker_session=sagemaker_session + ) + + with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT): + try: + logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...") + predictor = model.deploy(instance_type=gpu_instance_type, initial_instance_count=1) + logger.info("Endpoint successfully deployed.") + predictor.predict(input) + except Exception as e: + caught_ex = e + finally: + cleanup_model_resources( + sagemaker_session=model_builder.sagemaker_session, + model_name=model.name, + endpoint_name=model.endpoint_name, + ) + if caught_ex: + logger.exception(caught_ex) + assert ( + False + ), f"{caught_ex} was thrown when running pytorch transformers sagemaker endpoint test" diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index 953cbe775c..ef2f0d4516 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -85,6 +85,7 @@ }, limits={}, ) +HF_TASK = "audio-classification" @pytest.fixture @@ -1027,3 +1028,57 @@ def test_deploy_with_name_and_resources(sagemaker_session): async_inference_config_dict=None, live_logging=False, ) + + +@patch("sagemaker.model.Model._create_sagemaker_model", Mock()) +@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT) +def test_deploy_with_name_and_task(sagemaker_session): + sagemaker_session.sagemaker_config = {} + + model = Model( + MODEL_IMAGE, MODEL_DATA, task=HF_TASK, name=MODEL_NAME, role=ROLE, sagemaker_session=sagemaker_session + ) + + endpoint_name = "testing-task-input" + predictor = model.deploy( + endpoint_name=endpoint_name, + instance_type=INSTANCE_TYPE, + initial_instance_count=INSTANCE_COUNT, + ) + + sagemaker_session.create_model.assert_called_with( + name=MODEL_IMAGE, + role=ROLE, + task=HF_TASK + ) + + assert isinstance(predictor, sagemaker.predictor.Predictor) + assert predictor.endpoint_name == endpoint_name + assert predictor.sagemaker_session == sagemaker_session + + +@patch("sagemaker.model.Model._create_sagemaker_model", Mock()) +@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT) +def test_deploy_with_name_and_without_task(sagemaker_session): + sagemaker_session.sagemaker_config = {} + + model = Model( + MODEL_IMAGE, MODEL_DATA, name=MODEL_NAME, role=ROLE, sagemaker_session=sagemaker_session + ) + + endpoint_name = "testing-without-task-input" + predictor = model.deploy( + endpoint_name=endpoint_name, + instance_type=INSTANCE_TYPE, + initial_instance_count=INSTANCE_COUNT, + ) + + sagemaker_session.create_model.assert_called_with( + name=MODEL_IMAGE, + role=ROLE, + task=None, + ) + + assert isinstance(predictor, sagemaker.predictor.Predictor) + assert predictor.endpoint_name == endpoint_name + assert predictor.sagemaker_session == sagemaker_session From 9e4e2ecece7652b8cb6113b822fc8e0289ecb198 Mon Sep 17 00:00:00 2001 From: Xiong Zeng Date: Wed, 21 Feb 2024 22:17:01 +0000 Subject: [PATCH 02/11] Revert "feat: Add Optional task to Model" This reverts commit fd3e86b19f091dc1ce4d7906644107e08768c6a4. --- src/sagemaker/model.py | 6 +- src/sagemaker/serve/builder/model_builder.py | 1 - .../serve/test_serve_transformers.py | 56 ------------------- tests/unit/sagemaker/model/test_deploy.py | 55 ------------------ 4 files changed, 1 insertion(+), 117 deletions(-) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 36b49f1ac1..ff340b58e9 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -156,7 +156,6 @@ def __init__( dependencies: Optional[List[str]] = None, git_config: Optional[Dict[str, str]] = None, resources: Optional[ResourceRequirements] = None, - task: Optional[Union[str, PipelineVariable]] = None, ): """Initialize an SageMaker ``Model``. @@ -320,9 +319,7 @@ def __init__( for a model to be deployed to an endpoint. Only EndpointType.INFERENCE_COMPONENT_BASED supports this feature. (Default: None). - task (str or PipelineVariable): Task values used to override the HuggingFace task - Examples are: "audio-classification", "depth-estimation", - "feature-extraction" etc. (default: None). + """ self.model_data = model_data self.image_uri = image_uri @@ -399,7 +396,6 @@ def __init__( self.content_types = None self.response_types = None self.accept_eula = None - self.task = task @runnable_by_pipeline def register( diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 43060aed9a..8ca6a5d4ab 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -205,7 +205,6 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): "help": ( 'Model object with "predict" method to perform inference ' "or HuggingFace/JumpStart Model ID" - "or HuggingFace Task to override" ) }, ) diff --git a/tests/integ/sagemaker/serve/test_serve_transformers.py b/tests/integ/sagemaker/serve/test_serve_transformers.py index 958c469159..735f60d0f2 100644 --- a/tests/integ/sagemaker/serve/test_serve_transformers.py +++ b/tests/integ/sagemaker/serve/test_serve_transformers.py @@ -13,7 +13,6 @@ from __future__ import absolute_import import pytest -from sagemaker.model import Model from sagemaker.serve.builder.schema_builder import SchemaBuilder from sagemaker.serve.builder.model_builder import ModelBuilder, Mode @@ -29,11 +28,6 @@ logger = logging.getLogger(__name__) -MODEL_DATA = "s3://bucket/model.tar.gz" -MODEL_IMAGE = "mi" -ROLE = "some-role" -HF_TASK = "fill-mask" - sample_input = { "inputs": "The man worked as a [MASK].", } @@ -86,17 +80,6 @@ def model_builder_model_schema_builder(): ) -@pytest.fixture -def model_builder_model_with_task_builder(): - model = Model( - MODEL_IMAGE, MODEL_DATA, task=HF_TASK, name="bert-base-uncased", role=ROLE - ) - return ModelBuilder( - model_path=HF_DIR, - model=model, - ) - - @pytest.fixture def model_builder(request): return request.getfixturevalue(request.param) @@ -139,42 +122,3 @@ def test_pytorch_transformers_sagemaker_endpoint( assert ( False ), f"{caught_ex} was thrown when running pytorch transformers sagemaker endpoint test" - - -@pytest.mark.skipif( - PYTHON_VERSION_IS_NOT_310, - reason="Testing Optional task", -) -@pytest.mark.parametrize("model_builder", ["model_builder_model_with_task_builder"], indirect=True) -def test_happy_path_with_task_sagemaker_endpoint( - sagemaker_session, model_builder, gpu_instance_type, input -): - logger.info("Running in SAGEMAKER_ENDPOINT mode...") - caught_ex = None - - iam_client = sagemaker_session.boto_session.client("iam") - role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"] - - model = model_builder.build( - mode=Mode.SAGEMAKER_ENDPOINT, role_arn=role_arn, sagemaker_session=sagemaker_session - ) - - with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT): - try: - logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...") - predictor = model.deploy(instance_type=gpu_instance_type, initial_instance_count=1) - logger.info("Endpoint successfully deployed.") - predictor.predict(input) - except Exception as e: - caught_ex = e - finally: - cleanup_model_resources( - sagemaker_session=model_builder.sagemaker_session, - model_name=model.name, - endpoint_name=model.endpoint_name, - ) - if caught_ex: - logger.exception(caught_ex) - assert ( - False - ), f"{caught_ex} was thrown when running pytorch transformers sagemaker endpoint test" diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index ef2f0d4516..953cbe775c 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -85,7 +85,6 @@ }, limits={}, ) -HF_TASK = "audio-classification" @pytest.fixture @@ -1028,57 +1027,3 @@ def test_deploy_with_name_and_resources(sagemaker_session): async_inference_config_dict=None, live_logging=False, ) - - -@patch("sagemaker.model.Model._create_sagemaker_model", Mock()) -@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT) -def test_deploy_with_name_and_task(sagemaker_session): - sagemaker_session.sagemaker_config = {} - - model = Model( - MODEL_IMAGE, MODEL_DATA, task=HF_TASK, name=MODEL_NAME, role=ROLE, sagemaker_session=sagemaker_session - ) - - endpoint_name = "testing-task-input" - predictor = model.deploy( - endpoint_name=endpoint_name, - instance_type=INSTANCE_TYPE, - initial_instance_count=INSTANCE_COUNT, - ) - - sagemaker_session.create_model.assert_called_with( - name=MODEL_IMAGE, - role=ROLE, - task=HF_TASK - ) - - assert isinstance(predictor, sagemaker.predictor.Predictor) - assert predictor.endpoint_name == endpoint_name - assert predictor.sagemaker_session == sagemaker_session - - -@patch("sagemaker.model.Model._create_sagemaker_model", Mock()) -@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT) -def test_deploy_with_name_and_without_task(sagemaker_session): - sagemaker_session.sagemaker_config = {} - - model = Model( - MODEL_IMAGE, MODEL_DATA, name=MODEL_NAME, role=ROLE, sagemaker_session=sagemaker_session - ) - - endpoint_name = "testing-without-task-input" - predictor = model.deploy( - endpoint_name=endpoint_name, - instance_type=INSTANCE_TYPE, - initial_instance_count=INSTANCE_COUNT, - ) - - sagemaker_session.create_model.assert_called_with( - name=MODEL_IMAGE, - role=ROLE, - task=None, - ) - - assert isinstance(predictor, sagemaker.predictor.Predictor) - assert predictor.endpoint_name == endpoint_name - assert predictor.sagemaker_session == sagemaker_session From bdeb84b683cf5cae837211c1f73984fc56742c8f Mon Sep 17 00:00:00 2001 From: Xiong Zeng Date: Mon, 26 Feb 2024 19:16:13 +0000 Subject: [PATCH 03/11] Add override logic in ModelBuilder with task provided --- src/sagemaker/serve/builder/model_builder.py | 11 ++- src/sagemaker/serve/schema/task.json | 40 ++++----- .../sagemaker/serve/test_schema_builder.py | 55 ++++++++++++ .../serve/builder/test_model_builder.py | 88 +++++++++++++++++++ 4 files changed, 172 insertions(+), 22 deletions(-) diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 8ca6a5d4ab..82cf59bf45 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -118,7 +118,8 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): into a stream. All translations between the server and the client are handled automatically with the specified input and output. model (Optional[Union[object, str]): Model object (with ``predict`` method to perform - inference) or a HuggingFace/JumpStart Model ID. Either ``model`` or + inference) or a HuggingFace/JumpStart Model ID (followed by ``:task`` if you need + to override the task, e.g. bert-base-uncased:fill-mask). Either ``model`` or ``inference_spec`` is required for the model builder to build the artifact. inference_spec (InferenceSpec): The inference spec file with your customized ``invoke`` and ``load`` functions. @@ -205,6 +206,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): "help": ( 'Model object with "predict" method to perform inference ' "or HuggingFace/JumpStart Model ID" + "or if you need to override task, provide input as ModelID:Task" ) }, ) @@ -610,6 +612,10 @@ def build( self._is_custom_image_uri = self.image_uri is not None if isinstance(self.model, str): + model_task = None + if ":" in self.model: + model_task = self.model.split(":")[1] + self.model = self.model.split(":")[0] if self._is_jumpstart_model_id(): return self._build_for_jumpstart() if self._is_djl(): # pylint: disable=R1705 @@ -619,7 +625,8 @@ def build( self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") ) - model_task = hf_model_md.get("pipeline_tag") + if model_task is None: + model_task = hf_model_md.get("pipeline_tag") if self.schema_builder is None and model_task: self._schema_builder_init(model_task) diff --git a/src/sagemaker/serve/schema/task.json b/src/sagemaker/serve/schema/task.json index 9ee6d186a2..3fc44156f4 100644 --- a/src/sagemaker/serve/schema/task.json +++ b/src/sagemaker/serve/schema/task.json @@ -1,12 +1,12 @@ { "fill-mask": { - "sample_inputs": { + "sample_inputs": { "properties": { - "inputs": "Paris is the of France.", + "inputs": "Paris is the [MASK] of France.", "parameters": {} } }, - "sample_outputs": { + "sample_outputs": { "properties": [ { "sequence": "Paris is the capital of France.", @@ -14,15 +14,15 @@ } ] } - }, + }, "question-answering": { - "sample_inputs": { + "sample_inputs": { "properties": { "context": "I have a German Shepherd dog, named Coco.", "question": "What is my dog's breed?" } - }, - "sample_outputs": { + }, + "sample_outputs": { "properties": [ { "answer": "German Shepherd", @@ -32,15 +32,15 @@ } ] } - }, + }, "text-classification": { - "sample_inputs": { + "sample_inputs": { "properties": { "inputs": "Where is the capital of France?, Paris is the capital of France.", "parameters": {} } - }, - "sample_outputs": { + }, + "sample_outputs": { "properties": [ { "label": "entailment", @@ -48,20 +48,20 @@ } ] } - }, - "text-generation": { - "sample_inputs": { + }, + "text-generation": { + "sample_inputs": { "properties": { "inputs": "Hello, I'm a language model", "parameters": {} } - }, - "sample_outputs": { + }, + "sample_outputs": { "properties": [ - { - "generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to meet my wife or come home she told me that my" - } + { + "generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to meet my wife or come home she told me that my" + } ] } - } + } } diff --git a/tests/integ/sagemaker/serve/test_schema_builder.py b/tests/integ/sagemaker/serve/test_schema_builder.py index 3816985d8f..44c95809d8 100644 --- a/tests/integ/sagemaker/serve/test_schema_builder.py +++ b/tests/integ/sagemaker/serve/test_schema_builder.py @@ -99,3 +99,58 @@ def test_model_builder_negative_path(sagemaker_session): match="Error Message: Schema builder for text-to-image could not be found.", ): model_builder.build(sagemaker_session=sagemaker_session) + + +@pytest.mark.skipif( + PYTHON_VERSION_IS_NOT_310, + reason="Testing Schema Builder Simplification feature", +) +def test_model_builder_happy_path_with_task_provided(sagemaker_session, gpu_instance_type): + model_builder = ModelBuilder(model="bert-base-uncased:fill-mask") + + model = model_builder.build(sagemaker_session=sagemaker_session) + + assert model is not None + assert model_builder.schema_builder is not None + + inputs, outputs = task.retrieve_local_schemas("fill-mask") + assert model_builder.schema_builder.sample_input == inputs + assert model_builder.schema_builder.sample_output == outputs + + with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT): + caught_ex = None + try: + iam_client = sagemaker_session.boto_session.client("iam") + role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"] + + logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...") + predictor = model.deploy( + role=role_arn, instance_count=1, instance_type=gpu_instance_type + ) + + predicted_outputs = predictor.predict(inputs) + assert predicted_outputs is not None + + except Exception as e: + caught_ex = e + finally: + cleanup_model_resources( + sagemaker_session=model_builder.sagemaker_session, + model_name=model.name, + endpoint_name=model.endpoint_name, + ) + if caught_ex: + logger.exception(caught_ex) + assert ( + False + ), f"{caught_ex} was thrown when running transformers sagemaker endpoint test" + + +def test_model_builder_negative_path_with_invalid_task(sagemaker_session): + model_builder = ModelBuilder(model="bert-base-uncased:invalid-task") + + with pytest.raises( + TaskNotFoundException, + match="Error Message: Schema builder for invalid-task could not be found.", + ): + model_builder.build(sagemaker_session=sagemaker_session) diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index becf63ab41..1e4cedb240 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -1077,3 +1077,91 @@ def test_build_negative_path_when_schema_builder_not_present( "Error Message: Schema builder for text-to-image could not be found.", lambda: model_builder.build(sagemaker_session=mock_session), ) + + @patch("sagemaker.serve.builder.tgi_builder.HuggingFaceModel") + @patch("sagemaker.image_uris.retrieve") + @patch("sagemaker.djl_inference.model.urllib") + @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.huggingface.llm_utils.urllib") + @patch("sagemaker.huggingface.llm_utils.json") + @patch("sagemaker.model_uris.retrieve") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_build_happy_path_override_with_task_provided( + self, + mock_serveSettings, + mock_model_uris_retrieve, + mock_llm_utils_json, + mock_llm_utils_urllib, + mock_model_json, + mock_model_urllib, + mock_image_uris_retrieve, + mock_hf_model, + ): + # Setup mocks + + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + # HF Pipeline Tag + mock_model_uris_retrieve.side_effect = KeyError + mock_llm_utils_json.load.return_value = {"pipeline_tag": "fill-mask"} + mock_llm_utils_urllib.request.Request.side_effect = Mock() + + # HF Model config + mock_model_json.load.return_value = {"some": "config"} + mock_model_urllib.request.Request.side_effect = Mock() + + mock_image_uris_retrieve.return_value = "https://some-image-uri" + + model_builder = ModelBuilder(model="bert-base-uncased:text-generation") + model_builder.build(sagemaker_session=mock_session) + + self.assertIsNotNone(model_builder.schema_builder) + sample_inputs, sample_outputs = task.retrieve_local_schemas("text-generation") + self.assertEqual( + sample_inputs["inputs"], model_builder.schema_builder.sample_input["inputs"] + ) + self.assertEqual(sample_outputs, model_builder.schema_builder.sample_output) + + @patch("sagemaker.image_uris.retrieve") + @patch("sagemaker.djl_inference.model.urllib") + @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.huggingface.llm_utils.urllib") + @patch("sagemaker.huggingface.llm_utils.json") + @patch("sagemaker.model_uris.retrieve") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_build_negative_path_override_with_task_provided( + self, + mock_serveSettings, + mock_model_uris_retrieve, + mock_llm_utils_json, + mock_llm_utils_urllib, + mock_model_json, + mock_model_urllib, + mock_image_uris_retrieve, + ): + # Setup mocks + + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + # HF Pipeline Tag + mock_model_uris_retrieve.side_effect = KeyError + mock_llm_utils_json.load.return_value = {"pipeline_tag": "fill-mask"} + mock_llm_utils_urllib.request.Request.side_effect = Mock() + + # HF Model config + mock_model_json.load.return_value = {"some": "config"} + mock_model_urllib.request.Request.side_effect = Mock() + + mock_image_uris_retrieve.return_value = "https://some-image-uri" + + model_builder = ModelBuilder(model="bert-base-uncased:invalid-task") + + self.assertRaisesRegexp( + TaskNotFoundException, + "Error Message: Schema builder for invalid-task could not be found.", + lambda: model_builder.build(sagemaker_session=mock_session), + ) From 5f2bcea7a9d8b97b7497ecebf8739e0d113d94ee Mon Sep 17 00:00:00 2001 From: Xiong Zeng Date: Mon, 26 Feb 2024 19:21:53 +0000 Subject: [PATCH 04/11] Adjusted formatting --- src/sagemaker/serve/schema/task.json | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/serve/schema/task.json b/src/sagemaker/serve/schema/task.json index 3fc44156f4..c897f4abec 100644 --- a/src/sagemaker/serve/schema/task.json +++ b/src/sagemaker/serve/schema/task.json @@ -1,12 +1,12 @@ { "fill-mask": { - "sample_inputs": { + "sample_inputs": { "properties": { "inputs": "Paris is the [MASK] of France.", "parameters": {} } - }, - "sample_outputs": { + }, + "sample_outputs": { "properties": [ { "sequence": "Paris is the capital of France.", @@ -14,15 +14,15 @@ } ] } - }, + }, "question-answering": { - "sample_inputs": { + "sample_inputs": { "properties": { "context": "I have a German Shepherd dog, named Coco.", "question": "What is my dog's breed?" } - }, - "sample_outputs": { + }, + "sample_outputs": { "properties": [ { "answer": "German Shepherd", From 15a501d3ecd01b2292e9761de0c81e148a6f6986 Mon Sep 17 00:00:00 2001 From: Xiong Zeng Date: Tue, 27 Feb 2024 22:22:51 +0000 Subject: [PATCH 05/11] Add extra unit tests for invalid inputs --- src/sagemaker/huggingface/llm_utils.py | 3 +- src/sagemaker/serve/builder/model_builder.py | 3 +- .../serve/builder/test_model_builder.py | 44 +++++++++++++++---- 3 files changed, 39 insertions(+), 11 deletions(-) diff --git a/src/sagemaker/huggingface/llm_utils.py b/src/sagemaker/huggingface/llm_utils.py index 1a2abfb2e4..7002d4e7ea 100644 --- a/src/sagemaker/huggingface/llm_utils.py +++ b/src/sagemaker/huggingface/llm_utils.py @@ -81,7 +81,8 @@ def get_huggingface_model_metadata(model_id: str, hf_hub_token: Optional[str] = Returns: dict: The model metadata retrieved with the HuggingFace API """ - + if len(model_id) == 0: + raise ValueError("Model ID is empty. Please provide a valid Model ID.") hf_model_metadata_url = f"https://huggingface.co/api/models/{model_id}" hf_model_metadata_json = None try: diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 82cf59bf45..faae6fe0f9 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -627,9 +627,8 @@ def build( if model_task is None: model_task = hf_model_md.get("pipeline_tag") - if self.schema_builder is None and model_task: + if self.schema_builder is None and model_task is not None: self._schema_builder_init(model_task) - if model_task == "text-generation": # pylint: disable=R1705 return self._build_for_tgi() else: diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index 1e4cedb240..fb5c7e41fb 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -1072,7 +1072,7 @@ def test_build_negative_path_when_schema_builder_not_present( model_builder = ModelBuilder(model="CompVis/stable-diffusion-v1-4") - self.assertRaisesRegexp( + self.assertRaisesRegex( TaskNotFoundException, "Error Message: Schema builder for text-to-image could not be found.", lambda: model_builder.build(sagemaker_session=mock_session), @@ -1131,7 +1131,7 @@ def test_build_happy_path_override_with_task_provided( @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @patch("sagemaker.serve.builder.model_builder._ServeSettings") - def test_build_negative_path_override_with_task_provided( + def test_build_task_override_with_invalid_task_provided( self, mock_serveSettings, mock_model_uris_retrieve, @@ -1157,11 +1157,39 @@ def test_build_negative_path_override_with_task_provided( mock_model_urllib.request.Request.side_effect = Mock() mock_image_uris_retrieve.return_value = "https://some-image-uri" + model_ids_with_invalid_task = ["bert-base-uncased:invalid-task", "bert-base-uncased:"] + for model_id in model_ids_with_invalid_task: + model_builder = ModelBuilder(model=model_id) + + provided_task = model_id.split(":")[1] + self.assertRaisesRegex( + TaskNotFoundException, + f"Error Message: Schema builder for {provided_task} could not be found.", + lambda: model_builder.build(sagemaker_session=mock_session), + ) - model_builder = ModelBuilder(model="bert-base-uncased:invalid-task") + @patch("sagemaker.image_uris.retrieve") + @patch("sagemaker.model_uris.retrieve") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_build_task_override_with_invalid_model_provided( + self, + mock_serveSettings, + mock_model_uris_retrieve, + mock_image_uris_retrieve, + ): + # Setup mocks - self.assertRaisesRegexp( - TaskNotFoundException, - "Error Message: Schema builder for invalid-task could not be found.", - lambda: model_builder.build(sagemaker_session=mock_session), - ) + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + # HF Pipeline Tag + mock_model_uris_retrieve.side_effect = KeyError + + mock_image_uris_retrieve.return_value = "https://some-image-uri" + invalid_model_ids_with_task = [":fill-mask", "bert-base-uncased;fill-mask"] + + for model_id in invalid_model_ids_with_task: + model_builder = ModelBuilder(model=model_id) + with self.assertRaises(Exception): + model_builder.build(sagemaker_session=mock_session) From c0e505d8b04f0ea1a68df6d62b5e66453e51c8ac Mon Sep 17 00:00:00 2001 From: Xiong Zeng Date: Wed, 28 Feb 2024 22:16:20 +0000 Subject: [PATCH 06/11] Address PR comments --- src/sagemaker/huggingface/llm_utils.py | 2 +- tests/unit/sagemaker/serve/utils/test_task.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/huggingface/llm_utils.py b/src/sagemaker/huggingface/llm_utils.py index 7002d4e7ea..de5e624dbc 100644 --- a/src/sagemaker/huggingface/llm_utils.py +++ b/src/sagemaker/huggingface/llm_utils.py @@ -81,7 +81,7 @@ def get_huggingface_model_metadata(model_id: str, hf_hub_token: Optional[str] = Returns: dict: The model metadata retrieved with the HuggingFace API """ - if len(model_id) == 0: + if not model_id: raise ValueError("Model ID is empty. Please provide a valid Model ID.") hf_model_metadata_url = f"https://huggingface.co/api/models/{model_id}" hf_model_metadata_json = None diff --git a/tests/unit/sagemaker/serve/utils/test_task.py b/tests/unit/sagemaker/serve/utils/test_task.py index 78553968e1..431888e249 100644 --- a/tests/unit/sagemaker/serve/utils/test_task.py +++ b/tests/unit/sagemaker/serve/utils/test_task.py @@ -18,7 +18,7 @@ from sagemaker.serve.utils import task -EXPECTED_INPUTS = {"inputs": "Paris is the of France.", "parameters": {}} +EXPECTED_INPUTS = {"inputs": "Paris is the [MASK] of France.", "parameters": {}} EXPECTED_OUTPUTS = [{"sequence": "Paris is the capital of France.", "score": 0.7}] HF_INVALID_TASK = "not-present-task" From 86b6da695231d776f5d79a373efb5899e4cf3aeb Mon Sep 17 00:00:00 2001 From: Xiong Zeng Date: Thu, 29 Feb 2024 22:15:00 +0000 Subject: [PATCH 07/11] Add more test inputs to integration test --- .../integ/sagemaker/serve/test_schema_builder.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/integ/sagemaker/serve/test_schema_builder.py b/tests/integ/sagemaker/serve/test_schema_builder.py index 44c95809d8..cd08ac7b10 100644 --- a/tests/integ/sagemaker/serve/test_schema_builder.py +++ b/tests/integ/sagemaker/serve/test_schema_builder.py @@ -105,15 +105,24 @@ def test_model_builder_negative_path(sagemaker_session): PYTHON_VERSION_IS_NOT_310, reason="Testing Schema Builder Simplification feature", ) -def test_model_builder_happy_path_with_task_provided(sagemaker_session, gpu_instance_type): - model_builder = ModelBuilder(model="bert-base-uncased:fill-mask") +@pytest.mark.parametrize( + "model_id, task_provided", + [ + ("bert-base-uncased", "fill-mask"), + ("bert-large-uncased-whole-word-masking-finetuned-squad", "question-answering"), + ], +) +def test_model_builder_happy_path_with_task_provided( + model_id, task_provided, sagemaker_session, gpu_instance_type +): + model_builder = ModelBuilder(model=f"{model_id}:{task_provided}") model = model_builder.build(sagemaker_session=sagemaker_session) assert model is not None assert model_builder.schema_builder is not None - inputs, outputs = task.retrieve_local_schemas("fill-mask") + inputs, outputs = task.retrieve_local_schemas(task_provided) assert model_builder.schema_builder.sample_input == inputs assert model_builder.schema_builder.sample_output == outputs From 70c9d4f322bbe5f9fe25ae57795ee7ed0d9b7cbc Mon Sep 17 00:00:00 2001 From: Xiong Zeng Date: Wed, 6 Mar 2024 03:34:23 +0000 Subject: [PATCH 08/11] Add model_metadata field to ModelBuilder --- src/sagemaker/serve/builder/model_builder.py | 16 ++++++------ .../sagemaker/serve/test_schema_builder.py | 6 +++-- .../serve/builder/test_model_builder.py | 25 ++++++++++++------- 3 files changed, 29 insertions(+), 18 deletions(-) diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index faae6fe0f9..9ad5ef1547 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -118,9 +118,8 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): into a stream. All translations between the server and the client are handled automatically with the specified input and output. model (Optional[Union[object, str]): Model object (with ``predict`` method to perform - inference) or a HuggingFace/JumpStart Model ID (followed by ``:task`` if you need - to override the task, e.g. bert-base-uncased:fill-mask). Either ``model`` or - ``inference_spec`` is required for the model builder to build the artifact. + inference) or a HuggingFace/JumpStart Model ID. Either ``model`` or ``inference_spec`` + is required for the model builder to build the artifact. inference_spec (InferenceSpec): The inference spec file with your customized ``invoke`` and ``load`` functions. image_uri (Optional[str]): The container image uri (which is derived from a @@ -140,6 +139,8 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): to the model server). Possible values for this argument are ``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``, ``TRITON``, and``TGI``. + model_metadata (Optional[Dict[str, str]): Dictionary used to override the HuggingFace + model metadata. """ model_path: Optional[str] = field( @@ -206,7 +207,6 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): "help": ( 'Model object with "predict" method to perform inference ' "or HuggingFace/JumpStart Model ID" - "or if you need to override task, provide input as ModelID:Task" ) }, ) @@ -237,6 +237,9 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): model_server: Optional[ModelServer] = field( default=None, metadata={"help": "Define the model server to deploy to."} ) + model_metadata: Optional[Dict[str, str]] = field( + default=None, metadata={"help": "Define the model metadata to override"} + ) def _build_validations(self): """Placeholder docstring""" @@ -613,9 +616,8 @@ def build( if isinstance(self.model, str): model_task = None - if ":" in self.model: - model_task = self.model.split(":")[1] - self.model = self.model.split(":")[0] + if self.model_metadata: + model_task = self.model_metadata.get("HF_TASK") if self._is_jumpstart_model_id(): return self._build_for_jumpstart() if self._is_djl(): # pylint: disable=R1705 diff --git a/tests/integ/sagemaker/serve/test_schema_builder.py b/tests/integ/sagemaker/serve/test_schema_builder.py index cd08ac7b10..2b6ac48460 100644 --- a/tests/integ/sagemaker/serve/test_schema_builder.py +++ b/tests/integ/sagemaker/serve/test_schema_builder.py @@ -115,7 +115,7 @@ def test_model_builder_negative_path(sagemaker_session): def test_model_builder_happy_path_with_task_provided( model_id, task_provided, sagemaker_session, gpu_instance_type ): - model_builder = ModelBuilder(model=f"{model_id}:{task_provided}") + model_builder = ModelBuilder(model=model_id, model_metadata={"HF_TASK": task_provided}) model = model_builder.build(sagemaker_session=sagemaker_session) @@ -156,7 +156,9 @@ def test_model_builder_happy_path_with_task_provided( def test_model_builder_negative_path_with_invalid_task(sagemaker_session): - model_builder = ModelBuilder(model="bert-base-uncased:invalid-task") + model_builder = ModelBuilder( + model="bert-base-uncased", model_metadata={"HF_TASK": "invalid-task"} + ) with pytest.raises( TaskNotFoundException, diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index fb5c7e41fb..f208d0e9fc 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -1114,7 +1114,9 @@ def test_build_happy_path_override_with_task_provided( mock_image_uris_retrieve.return_value = "https://some-image-uri" - model_builder = ModelBuilder(model="bert-base-uncased:text-generation") + model_builder = ModelBuilder( + model="bert-base-uncased", model_metadata={"HF_TASK": "text-generation"} + ) model_builder.build(sagemaker_session=mock_session) self.assertIsNotNone(model_builder.schema_builder) @@ -1157,11 +1159,14 @@ def test_build_task_override_with_invalid_task_provided( mock_model_urllib.request.Request.side_effect = Mock() mock_image_uris_retrieve.return_value = "https://some-image-uri" - model_ids_with_invalid_task = ["bert-base-uncased:invalid-task", "bert-base-uncased:"] + model_ids_with_invalid_task = { + "bert-base-uncased": "invalid-task", + "bert-large-uncased-whole-word-masking-finetuned-squad": "", + } for model_id in model_ids_with_invalid_task: - model_builder = ModelBuilder(model=model_id) + provided_task = model_ids_with_invalid_task[model_id] + model_builder = ModelBuilder(model=model_id, model_metadata={"HF_TASK": provided_task}) - provided_task = model_id.split(":")[1] self.assertRaisesRegex( TaskNotFoundException, f"Error Message: Schema builder for {provided_task} could not be found.", @@ -1187,9 +1192,11 @@ def test_build_task_override_with_invalid_model_provided( mock_model_uris_retrieve.side_effect = KeyError mock_image_uris_retrieve.return_value = "https://some-image-uri" - invalid_model_ids_with_task = [":fill-mask", "bert-base-uncased;fill-mask"] + invalid_model_id = "" + provided_task = "fill-mask" - for model_id in invalid_model_ids_with_task: - model_builder = ModelBuilder(model=model_id) - with self.assertRaises(Exception): - model_builder.build(sagemaker_session=mock_session) + model_builder = ModelBuilder( + model=invalid_model_id, model_metadata={"HF_TASK": provided_task} + ) + with self.assertRaises(Exception): + model_builder.build(sagemaker_session=mock_session) From bd83555886a007b1c172461bb34e2ce7ddba46f0 Mon Sep 17 00:00:00 2001 From: Xiong Zeng Date: Wed, 6 Mar 2024 20:58:30 +0000 Subject: [PATCH 09/11] Update doc --- src/sagemaker/serve/builder/model_builder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 9ad5ef1547..6b63c03a53 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -140,7 +140,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): ``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``, ``TRITON``, and``TGI``. model_metadata (Optional[Dict[str, str]): Dictionary used to override the HuggingFace - model metadata. + model metadata. Currently ``HF_TASK`` is overridable. """ model_path: Optional[str] = field( @@ -237,8 +237,8 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): model_server: Optional[ModelServer] = field( default=None, metadata={"help": "Define the model server to deploy to."} ) - model_metadata: Optional[Dict[str, str]] = field( - default=None, metadata={"help": "Define the model metadata to override"} + model_metadata: Optional[Dict[str, Any]] = field( + default=None, metadata={"help": "Define the model metadata to override, currently supports `HF_TASK`"} ) def _build_validations(self): From c06ef537b673034ed2b817237dab38b3c7222eef Mon Sep 17 00:00:00 2001 From: Xiong Zeng Date: Wed, 6 Mar 2024 21:31:38 +0000 Subject: [PATCH 10/11] Update doc --- src/sagemaker/serve/builder/model_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 25385c6709..6f91f948cc 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -145,7 +145,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): to the model server). Possible values for this argument are ``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``, ``TRITON``, and``TGI``. - model_metadata (Optional[Dict[str, str]): Dictionary used to override the HuggingFace + model_metadata (Optional[Dict[str, Any]): Dictionary used to override the HuggingFace model metadata. Currently ``HF_TASK`` is overridable. """ From 5b9f374df0993a16e2348fe812f432566b6d1b45 Mon Sep 17 00:00:00 2001 From: Xiong Zeng Date: Wed, 6 Mar 2024 21:38:02 +0000 Subject: [PATCH 11/11] Adjust formatting --- src/sagemaker/serve/builder/model_builder.py | 3 ++- tests/unit/sagemaker/serve/builder/test_model_builder.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 6f91f948cc..4d1e51cb26 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -244,7 +244,8 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): default=None, metadata={"help": "Define the model server to deploy to."} ) model_metadata: Optional[Dict[str, Any]] = field( - default=None, metadata={"help": "Define the model metadata to override, currently supports `HF_TASK`"} + default=None, + metadata={"help": "Define the model metadata to override, currently supports `HF_TASK`"}, ) def _build_validations(self): diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index bec3db2e2f..3b60d13dfb 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -1715,4 +1715,4 @@ def test_build_task_override_with_invalid_model_provided( model=invalid_model_id, model_metadata={"HF_TASK": provided_task} ) with self.assertRaises(Exception): - model_builder.build(sagemaker_session=mock_session) \ No newline at end of file + model_builder.build(sagemaker_session=mock_session)