From b2d0929f53ca1969c105e9464b934af30e427e09 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Thu, 15 Feb 2024 14:09:33 -0800 Subject: [PATCH 01/13] Fetch Schema locally --- src/sagemaker/image_uri_config/tasks.json | 80 +++++++++++++++++ src/sagemaker/serve/builder/model_builder.py | 28 +++++- src/sagemaker/task.py | 91 ++++++++++++++++++++ tests/unit/sagemaker/test_task.py | 31 +++++++ 4 files changed, 228 insertions(+), 2 deletions(-) create mode 100644 src/sagemaker/image_uri_config/tasks.json create mode 100644 src/sagemaker/task.py create mode 100644 tests/unit/sagemaker/test_task.py diff --git a/src/sagemaker/image_uri_config/tasks.json b/src/sagemaker/image_uri_config/tasks.json new file mode 100644 index 0000000000..5ba18de3fe --- /dev/null +++ b/src/sagemaker/image_uri_config/tasks.json @@ -0,0 +1,80 @@ +{ + "description": "Sample Task Inputs and Outputs", + "fill-mask": { + "ref": "https://huggingface.co/tasks/fill-mask", + "inputs": { + "ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/fill-mask/spec/input.json", + "properties": { + "inputs": "Paris is the of France.", + "parameters": {} + } + }, + "outputs": { + "ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/fill-mask/spec/output.json", + "properties": [ + { + "sequence": "Paris is the capital of France.", + "score": 0.7 + } + ] + } + }, + "question-answering": { + "ref": "https://huggingface.co/tasks/question-answering", + "inputs": { + "ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/question-answering/spec/input.json", + "properties": { + "context": "I have a German Shepherd dog, named Coco.", + "question": "What is my dog's breed?" + } + }, + "outputs": { + "ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/question-answering/spec/output.json", + "properties": [ + { + "answer": "German Shepherd", + "score": 0.972, + "start": 9, + "end": 24 + } + ] + } + }, + "text-classification": { + "ref": "https://huggingface.co/tasks/text-classification", + "inputs": { + "ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/text-classification/spec/input.json", + "properties": { + "inputs": "Where is the capital of France?, Paris is the capital of France.", + "parameters": {} + } + }, + "outputs": { + "ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/text-classification/spec/output.json", + "properties": [ + { + "label": "entailment", + "score": 0.997 + } + ] + } + }, + "text-generation": { + "ref": "https://huggingface.co/tasks/text-generation", + "inputs": { + "ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/text-generation/spec/input.json", + "properties": { + "inputs": "Hello, I'm a language model", + "parameters": {} + } + }, + "outputs": { + "ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/text-generation/spec/output.json", + "properties": [ + { + "generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to meet my wife or come home she told me that my" + } + ] + } + } +} \ No newline at end of file diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 0ade8096f6..c4716cc7ab 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -20,7 +20,7 @@ from pathlib import Path -from sagemaker import Session +from sagemaker import Session, task from sagemaker.model import Model from sagemaker.base_predictor import PredictorBase from sagemaker.serializers import NumpySerializer, TorchTensorSerializer @@ -609,12 +609,20 @@ def build( if self._is_jumpstart_model_id(): return self._build_for_jumpstart() if self._is_djl(): # pylint: disable=R1705 + if self.schema_builder is None: + self.__schema_builder_init("text-generation") + return self._build_for_djl() else: hf_model_md = get_huggingface_model_metadata( self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") ) - if hf_model_md.get("pipeline_tag") == "text-generation": # pylint: disable=R1705 + + hf_task = hf_model_md.get("pipeline_tag") + if self.schema_builder is None: + self.__schema_builder_init(hf_task) + + if hf_task == "text-generation": # pylint: disable=R1705 return self._build_for_tgi() else: return self._build_for_transformers() @@ -672,3 +680,19 @@ def validate(self, model_dir: str) -> Type[bool]: """ return get_metadata(model_dir) + + def __schema_builder_init(self, model_task: str): + """Initialize the""" + sample_input, sample_output = None, None + + try: + sample_input, sample_output = task.retrieve_local_schemas(model_task) + except ValueError: + # TODO: try to retrieve schemas remotely + pass + + if sample_input and sample_output: + self.schema_builder = SchemaBuilder(sample_input, sample_output) + else: + # TODO: Raise ClientError + pass diff --git a/src/sagemaker/task.py b/src/sagemaker/task.py new file mode 100644 index 0000000000..479e217ace --- /dev/null +++ b/src/sagemaker/task.py @@ -0,0 +1,91 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Accessors to retrieve task fallback input/output schema""" +from __future__ import absolute_import + +import json +import os +from enum import Enum +from typing import Any, Tuple + + +class TASK(str, Enum): + """Enum class for tasks""" + + AUDIO_CLASSIFICATION = "audio-classification" + AUDIO_TO_AUDIO = "audio-to-audio" + AUTOMATIC_SPEECH_RECOGNITION = "automatic-speech-recognition" + CONVERSATIONAL = "conversational" + DEPTH_ESTIMATION = "depth-estimation" + DOCUMENT_QUESTION_ANSWERING = "document-question-answering" + FEATURE_EXTRACTION = "feature-extraction" + FILL_MASK = "fill-mask" + IMAGE_CLASSIFICATION = "image-classification" + IMAGE_SEGMENTATION = "image-segmentation" + IMAGE_TO_IMAGE = "image-to-image" + IMAGE_TO_TEXT = "image-to-text" + MASK_GENERATION = "mask-generation" + OBJECT_DETECTION = "object-detection" + PLACEHOLDER = "placeholder" + QUESTION_ANSWERING = "question-answering" + REINFORCEMENT_LEARNING = "reinforcement-learning" + SENTENCE_SIMILARITY = "sentence-similarity" + SUMMARIZATION = "summarization" + TABLE_QUESTION_ANSWERING = "table-question-answering" + TABULAR_CLASSIFICATION = "tabular-classification" + TEXT_CLASSIFICATION = "text-classification" + TEXT_GENERATION = "text-generation" + TEXT_TO_AUDIO = "text-to-audio" + TEXT_TO_SPEECH = "text-to-speech" + TEXT_TO_VIDEO = "text-to-video" + TEXT_2_TEXT_GENERATION = "text2text-generation" + TOKEN_CLASSIFICATION = "token-classification" + TRANSLATION = "translation" + UNCONDITIONAL_IMAGE_GENERATION = "unconditional-image-generation" + VIDEO_CLASSIFICATION = "video-classification" + VISUAL_QUESTION_ANSWERING = "visual-question-answering" + ZERO_SHOT_CLASSIFICATION = "zero-shot-classification" + ZERO_SHOT_IMAGE_CLASSIFICATION = "zero-shot-image-classification" + ZERO_SHOT_OBJECT_DETECTION = "zero-shot-object-detection" + + +def retrieve_local_schemas(task: str) -> Tuple[Any, Any]: + """Retrieves task sample inputs and outputs locally. + + Args: + task (str): Required, the task name + + Returns: + Tuple[Any, Any]: A tuple that contains the sample input, + at index 0, and output schema, at index 1. + + Raises: + ValueError: If no tasks config found or the task does not exist in the local config. + """ + task_path = os.path.join(os.path.dirname(__file__), "image_uri_config", "tasks.json") + try: + with open(task_path) as f: + task_config = json.load(f) + task_schema = task_config.get(task, None) + + if task_schema is None: + raise ValueError(f"Could not find {task} task schema.") + + sample_schema = ( + task_schema["inputs"]["properties"], + task_schema["outputs"]["properties"], + ) + return sample_schema + + except FileNotFoundError: + raise ValueError("Could not find tasks config file.") diff --git a/tests/unit/sagemaker/test_task.py b/tests/unit/sagemaker/test_task.py new file mode 100644 index 0000000000..ce76c13df9 --- /dev/null +++ b/tests/unit/sagemaker/test_task.py @@ -0,0 +1,31 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +from sagemaker import task + +EXPECTED_INPUTS = {"inputs": "Paris is the of France.", "parameters": {}} +EXPECTED_OUTPUTS = [{"sequence": "Paris is the capital of France.", "score": 0.7}] + + +def test_retrieve_local_schemas_success(): + inputs, outputs = task.retrieve_local_schemas("fill-mask") + + assert inputs == EXPECTED_INPUTS + assert outputs == EXPECTED_OUTPUTS + + +def test_retrieve_local_schemas_throws(): + with pytest.raises(ValueError): + task.retrieve_local_schemas("invalid-task") From c4fb9f71e04ade422365ecc189facaccb8ee693d Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Thu, 15 Feb 2024 15:07:55 -0800 Subject: [PATCH 02/13] Fetch Schema locally --- src/sagemaker/serve/builder/model_builder.py | 2 +- src/sagemaker/task.py | 40 -------------------- 2 files changed, 1 insertion(+), 41 deletions(-) diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index c4716cc7ab..ac819913db 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -681,7 +681,7 @@ def validate(self, model_dir: str) -> Type[bool]: return get_metadata(model_dir) - def __schema_builder_init(self, model_task: str): + def _schema_builder_init(self, model_task: str): """Initialize the""" sample_input, sample_output = None, None diff --git a/src/sagemaker/task.py b/src/sagemaker/task.py index 479e217ace..00c9eaf30a 100644 --- a/src/sagemaker/task.py +++ b/src/sagemaker/task.py @@ -19,46 +19,6 @@ from typing import Any, Tuple -class TASK(str, Enum): - """Enum class for tasks""" - - AUDIO_CLASSIFICATION = "audio-classification" - AUDIO_TO_AUDIO = "audio-to-audio" - AUTOMATIC_SPEECH_RECOGNITION = "automatic-speech-recognition" - CONVERSATIONAL = "conversational" - DEPTH_ESTIMATION = "depth-estimation" - DOCUMENT_QUESTION_ANSWERING = "document-question-answering" - FEATURE_EXTRACTION = "feature-extraction" - FILL_MASK = "fill-mask" - IMAGE_CLASSIFICATION = "image-classification" - IMAGE_SEGMENTATION = "image-segmentation" - IMAGE_TO_IMAGE = "image-to-image" - IMAGE_TO_TEXT = "image-to-text" - MASK_GENERATION = "mask-generation" - OBJECT_DETECTION = "object-detection" - PLACEHOLDER = "placeholder" - QUESTION_ANSWERING = "question-answering" - REINFORCEMENT_LEARNING = "reinforcement-learning" - SENTENCE_SIMILARITY = "sentence-similarity" - SUMMARIZATION = "summarization" - TABLE_QUESTION_ANSWERING = "table-question-answering" - TABULAR_CLASSIFICATION = "tabular-classification" - TEXT_CLASSIFICATION = "text-classification" - TEXT_GENERATION = "text-generation" - TEXT_TO_AUDIO = "text-to-audio" - TEXT_TO_SPEECH = "text-to-speech" - TEXT_TO_VIDEO = "text-to-video" - TEXT_2_TEXT_GENERATION = "text2text-generation" - TOKEN_CLASSIFICATION = "token-classification" - TRANSLATION = "translation" - UNCONDITIONAL_IMAGE_GENERATION = "unconditional-image-generation" - VIDEO_CLASSIFICATION = "video-classification" - VISUAL_QUESTION_ANSWERING = "visual-question-answering" - ZERO_SHOT_CLASSIFICATION = "zero-shot-classification" - ZERO_SHOT_IMAGE_CLASSIFICATION = "zero-shot-image-classification" - ZERO_SHOT_OBJECT_DETECTION = "zero-shot-object-detection" - - def retrieve_local_schemas(task: str) -> Tuple[Any, Any]: """Retrieves task sample inputs and outputs locally. From 3b5e6995364a469a3f397cdfaed7d70b211ca6ea Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Thu, 15 Feb 2024 15:11:17 -0800 Subject: [PATCH 03/13] Local schema --- src/sagemaker/serve/builder/model_builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index ac819913db..00b158425e 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -610,7 +610,7 @@ def build( return self._build_for_jumpstart() if self._is_djl(): # pylint: disable=R1705 if self.schema_builder is None: - self.__schema_builder_init("text-generation") + self._schema_builder_init("text-generation") return self._build_for_djl() else: @@ -620,7 +620,7 @@ def build( hf_task = hf_model_md.get("pipeline_tag") if self.schema_builder is None: - self.__schema_builder_init(hf_task) + self._schema_builder_init(hf_task) if hf_task == "text-generation": # pylint: disable=R1705 return self._build_for_tgi() From edb5716be14e243df1b8543ededcf63206e10fb3 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Thu, 15 Feb 2024 15:17:09 -0800 Subject: [PATCH 04/13] Test local schemas --- src/sagemaker/serve/builder/model_builder.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 00b158425e..815a731996 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -605,24 +605,19 @@ def build( self.serve_settings = self._get_serve_setting() + sample_input, sample_output = task.retrieve_local_schemas("text-generation") + self.schema_builder = SchemaBuilder(sample_input, sample_output) + if isinstance(self.model, str): if self._is_jumpstart_model_id(): return self._build_for_jumpstart() if self._is_djl(): # pylint: disable=R1705 - if self.schema_builder is None: - self._schema_builder_init("text-generation") - return self._build_for_djl() else: hf_model_md = get_huggingface_model_metadata( self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") ) - - hf_task = hf_model_md.get("pipeline_tag") - if self.schema_builder is None: - self._schema_builder_init(hf_task) - - if hf_task == "text-generation": # pylint: disable=R1705 + if hf_model_md.get("pipeline_tag") == "text-generation": # pylint: disable=R1705 return self._build_for_tgi() else: return self._build_for_transformers() From 1b2a4fbc9b08e56717b98e0f04d80ba569655e77 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Thu, 15 Feb 2024 17:14:34 -0800 Subject: [PATCH 05/13] Testing --- src/sagemaker/serve/builder/model_builder.py | 24 +++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 815a731996..3727e4219b 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -605,19 +605,24 @@ def build( self.serve_settings = self._get_serve_setting() - sample_input, sample_output = task.retrieve_local_schemas("text-generation") - self.schema_builder = SchemaBuilder(sample_input, sample_output) - if isinstance(self.model, str): if self._is_jumpstart_model_id(): return self._build_for_jumpstart() if self._is_djl(): # pylint: disable=R1705 return self._build_for_djl() else: + logger.info("******************************************************") + logger.info(f"schema_builder is None: {self.schema_builder is None}") + hf_model_md = get_huggingface_model_metadata( self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") ) - if hf_model_md.get("pipeline_tag") == "text-generation": # pylint: disable=R1705 + + hf_task = hf_model_md.get("pipeline_tag") + logger.info(f"hf_task: {hf_task}") + self._schema_builder_init(hf_task) + + if hf_task == "text-generation": # pylint: disable=R1705 return self._build_for_tgi() else: return self._build_for_transformers() @@ -678,16 +683,19 @@ def validate(self, model_dir: str) -> Type[bool]: def _schema_builder_init(self, model_task: str): """Initialize the""" - sample_input, sample_output = None, None + sample_inputs, sample_outputs = None, None try: - sample_input, sample_output = task.retrieve_local_schemas(model_task) + sample_inputs, sample_outputs = task.retrieve_local_schemas(model_task) + logger.info(f"Sample input: {sample_inputs}") + logger.info(f"Sample output: {sample_outputs}") except ValueError: # TODO: try to retrieve schemas remotely pass - if sample_input and sample_output: - self.schema_builder = SchemaBuilder(sample_input, sample_output) + if sample_inputs and sample_outputs: + self.schema_builder = SchemaBuilder(sample_inputs, sample_outputs) + logger.info(f"schema_builder is not None: {self.schema_builder is None}") else: # TODO: Raise ClientError pass From b2acab34ca702e87991e2318fe0630b939385f70 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Thu, 15 Feb 2024 17:21:56 -0800 Subject: [PATCH 06/13] Testing Schema --- src/sagemaker/serve/builder/model_builder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 3727e4219b..86fa0b93a9 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -620,7 +620,8 @@ def build( hf_task = hf_model_md.get("pipeline_tag") logger.info(f"hf_task: {hf_task}") - self._schema_builder_init(hf_task) + if self.schema_builder is None: + self._schema_builder_init(hf_task) if hf_task == "text-generation": # pylint: disable=R1705 return self._build_for_tgi() From ad2303f8c11dba53e75b330c09f43bd7e62ec9c8 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Thu, 15 Feb 2024 17:26:29 -0800 Subject: [PATCH 07/13] Schema for DJL --- src/sagemaker/serve/builder/model_builder.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 86fa0b93a9..73d924138e 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -609,17 +609,16 @@ def build( if self._is_jumpstart_model_id(): return self._build_for_jumpstart() if self._is_djl(): # pylint: disable=R1705 + if self.schema_builder is None: + self._schema_builder_init("text-generation") + return self._build_for_djl() else: - logger.info("******************************************************") - logger.info(f"schema_builder is None: {self.schema_builder is None}") - hf_model_md = get_huggingface_model_metadata( self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") ) hf_task = hf_model_md.get("pipeline_tag") - logger.info(f"hf_task: {hf_task}") if self.schema_builder is None: self._schema_builder_init(hf_task) @@ -688,15 +687,12 @@ def _schema_builder_init(self, model_task: str): try: sample_inputs, sample_outputs = task.retrieve_local_schemas(model_task) - logger.info(f"Sample input: {sample_inputs}") - logger.info(f"Sample output: {sample_outputs}") except ValueError: # TODO: try to retrieve schemas remotely pass if sample_inputs and sample_outputs: self.schema_builder = SchemaBuilder(sample_inputs, sample_outputs) - logger.info(f"schema_builder is not None: {self.schema_builder is None}") else: # TODO: Raise ClientError pass From 17ae5d9d52f56cfd7c7abbca58daf7905a1aae19 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Fri, 16 Feb 2024 16:08:26 -0800 Subject: [PATCH 08/13] Add Integ tests --- src/sagemaker/serve/builder/model_builder.py | 29 ++++++------ src/sagemaker/serve/utils/exceptions.py | 9 ++++ src/sagemaker/task.py | 18 ++++---- .../sagemaker/serve/test_schema_builder.py | 46 +++++++++++++++++++ tests/unit/sagemaker/test_task.py | 9 +++- 5 files changed, 84 insertions(+), 27 deletions(-) create mode 100644 tests/integ/sagemaker/serve/test_schema_builder.py diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 73d924138e..1306a964ac 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -38,6 +38,7 @@ from sagemaker.predictor import Predictor from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import Metadata from sagemaker.serve.spec.inference_spec import InferenceSpec +from sagemaker.serve.utils.exceptions import TaskNotFoundException from sagemaker.serve.utils.predictors import _get_local_mode_predictor from sagemaker.serve.detector.image_detector import ( auto_detect_container, @@ -609,20 +610,17 @@ def build( if self._is_jumpstart_model_id(): return self._build_for_jumpstart() if self._is_djl(): # pylint: disable=R1705 - if self.schema_builder is None: - self._schema_builder_init("text-generation") - return self._build_for_djl() else: hf_model_md = get_huggingface_model_metadata( self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") ) - hf_task = hf_model_md.get("pipeline_tag") + model_task = hf_model_md.get("pipeline_tag") if self.schema_builder is None: - self._schema_builder_init(hf_task) + self._schema_builder_init(model_task) - if hf_task == "text-generation": # pylint: disable=R1705 + if model_task == "text-generation": # pylint: disable=R1705 return self._build_for_tgi() else: return self._build_for_transformers() @@ -682,17 +680,16 @@ def validate(self, model_dir: str) -> Type[bool]: return get_metadata(model_dir) def _schema_builder_init(self, model_task: str): - """Initialize the""" - sample_inputs, sample_outputs = None, None + """Initialize the schema builder + Args: + model_task (str): Required, the task name + + Raises: + TaskNotFoundException: If the I/O schema for the given task is not found. + """ try: sample_inputs, sample_outputs = task.retrieve_local_schemas(model_task) - except ValueError: - # TODO: try to retrieve schemas remotely - pass - - if sample_inputs and sample_outputs: self.schema_builder = SchemaBuilder(sample_inputs, sample_outputs) - else: - # TODO: Raise ClientError - pass + except ValueError: + raise TaskNotFoundException(f"Schema builder for {model_task} could not be found.") diff --git a/src/sagemaker/serve/utils/exceptions.py b/src/sagemaker/serve/utils/exceptions.py index f14677d93a..88a51c63f0 100644 --- a/src/sagemaker/serve/utils/exceptions.py +++ b/src/sagemaker/serve/utils/exceptions.py @@ -60,3 +60,12 @@ class SkipTuningComboException(ModelBuilderException): def __init__(self, message): super().__init__(message=message) + + +class TaskNotFoundException(ModelBuilderException): + """Raise when task could not be found""" + + fmt = "Error Message: {message}" + + def __init__(self, message): + super().__init__(message=message) diff --git a/src/sagemaker/task.py b/src/sagemaker/task.py index 00c9eaf30a..f221c74890 100644 --- a/src/sagemaker/task.py +++ b/src/sagemaker/task.py @@ -15,7 +15,6 @@ import json import os -from enum import Enum from typing import Any, Tuple @@ -32,20 +31,19 @@ def retrieve_local_schemas(task: str) -> Tuple[Any, Any]: Raises: ValueError: If no tasks config found or the task does not exist in the local config. """ - task_path = os.path.join(os.path.dirname(__file__), "image_uri_config", "tasks.json") + task_io_config_path = os.path.join(os.path.dirname(__file__), "image_uri_config", "tasks.json") try: - with open(task_path) as f: - task_config = json.load(f) - task_schema = task_config.get(task, None) + with open(task_io_config_path) as f: + task_io_config = json.load(f) + task_io_schemas = task_io_config.get(task, None) - if task_schema is None: - raise ValueError(f"Could not find {task} task schema.") + if task_io_schemas is None: + raise ValueError(f"Could not find {task} I/O schema.") sample_schema = ( - task_schema["inputs"]["properties"], - task_schema["outputs"]["properties"], + task_io_schemas["inputs"]["properties"], + task_io_schemas["outputs"]["properties"], ) return sample_schema - except FileNotFoundError: raise ValueError("Could not find tasks config file.") diff --git a/tests/integ/sagemaker/serve/test_schema_builder.py b/tests/integ/sagemaker/serve/test_schema_builder.py new file mode 100644 index 0000000000..679d1a0f47 --- /dev/null +++ b/tests/integ/sagemaker/serve/test_schema_builder.py @@ -0,0 +1,46 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from sagemaker import task +from sagemaker.serve.builder.model_builder import ModelBuilder + +import logging + +logger = logging.getLogger(__name__) + + +def test_model_builder_happy_path_with_only_model_id_fill_mask(sagemaker_session): + model_builder = ModelBuilder(model="bert-base-uncased") + + model = model_builder.build(sagemaker_session=sagemaker_session) + + assert model is not None + assert model_builder.schema_builder is not None + + inputs, outputs = task.retrieve_local_schemas("fill-mask") + assert model_builder.schema_builder.sample_input == inputs + assert model_builder.schema_builder.sample_output == outputs + + +def test_model_builder_happy_path_with_only_model_id_question_answering(sagemaker_session): + model_builder = ModelBuilder(model="bert-large-uncased-whole-word-masking-finetuned-squad") + + model = model_builder.build(sagemaker_session=sagemaker_session) + + assert model is not None + assert model_builder.schema_builder is not None + + inputs, outputs = task.retrieve_local_schemas("question-answering") + assert model_builder.schema_builder.sample_input == inputs + assert model_builder.schema_builder.sample_output == outputs diff --git a/tests/unit/sagemaker/test_task.py b/tests/unit/sagemaker/test_task.py index ce76c13df9..b83392b61d 100644 --- a/tests/unit/sagemaker/test_task.py +++ b/tests/unit/sagemaker/test_task.py @@ -26,6 +26,13 @@ def test_retrieve_local_schemas_success(): assert outputs == EXPECTED_OUTPUTS +def test_retrieve_local_schemas_text_generation_success(): + inputs, outputs = task.retrieve_local_schemas("text-generation") + + assert inputs is not None + assert outputs is not None + + def test_retrieve_local_schemas_throws(): with pytest.raises(ValueError): - task.retrieve_local_schemas("invalid-task") + task.retrieve_local_schemas("not-present-task") From 98dd0a9521388b624f1f7966f251408b4615c260 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Mon, 19 Feb 2024 14:27:44 -0800 Subject: [PATCH 09/13] address PR comments --- src/sagemaker/image_uri_config/tasks.json | 29 +++------ src/sagemaker/serve/builder/model_builder.py | 5 +- src/sagemaker/serve/utils/exceptions.py | 2 +- src/sagemaker/{ => serve/utils}/task.py | 8 ++- .../sagemaker/serve/test_schema_builder.py | 59 ++++++++++++++++++- .../sagemaker/{ => serve/utils}/test_task.py | 17 +++++- 6 files changed, 88 insertions(+), 32 deletions(-) rename src/sagemaker/{ => serve/utils}/task.py (80%) rename tests/unit/sagemaker/{ => serve/utils}/test_task.py (69%) diff --git a/src/sagemaker/image_uri_config/tasks.json b/src/sagemaker/image_uri_config/tasks.json index 5ba18de3fe..b658e2d5bd 100644 --- a/src/sagemaker/image_uri_config/tasks.json +++ b/src/sagemaker/image_uri_config/tasks.json @@ -1,16 +1,12 @@ { - "description": "Sample Task Inputs and Outputs", "fill-mask": { - "ref": "https://huggingface.co/tasks/fill-mask", - "inputs": { - "ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/fill-mask/spec/input.json", + "sample_inputs": { "properties": { "inputs": "Paris is the of France.", "parameters": {} } }, - "outputs": { - "ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/fill-mask/spec/output.json", + "sample_outputs": { "properties": [ { "sequence": "Paris is the capital of France.", @@ -20,16 +16,13 @@ } }, "question-answering": { - "ref": "https://huggingface.co/tasks/question-answering", - "inputs": { - "ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/question-answering/spec/input.json", + "sample_inputs": { "properties": { "context": "I have a German Shepherd dog, named Coco.", "question": "What is my dog's breed?" } }, - "outputs": { - "ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/question-answering/spec/output.json", + "sample_outputs": { "properties": [ { "answer": "German Shepherd", @@ -41,16 +34,13 @@ } }, "text-classification": { - "ref": "https://huggingface.co/tasks/text-classification", - "inputs": { - "ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/text-classification/spec/input.json", + "sample_inputs": { "properties": { "inputs": "Where is the capital of France?, Paris is the capital of France.", "parameters": {} } }, - "outputs": { - "ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/text-classification/spec/output.json", + "sample_outputs": { "properties": [ { "label": "entailment", @@ -60,16 +50,13 @@ } }, "text-generation": { - "ref": "https://huggingface.co/tasks/text-generation", - "inputs": { - "ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/text-generation/spec/input.json", + "sample_inputs": { "properties": { "inputs": "Hello, I'm a language model", "parameters": {} } }, - "outputs": { - "ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/text-generation/spec/output.json", + "sample_outputs": { "properties": [ { "generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to meet my wife or come home she told me that my" diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 1306a964ac..11a291ee47 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -20,7 +20,7 @@ from pathlib import Path -from sagemaker import Session, task +from sagemaker import Session from sagemaker.model import Model from sagemaker.base_predictor import PredictorBase from sagemaker.serializers import NumpySerializer, TorchTensorSerializer @@ -38,6 +38,7 @@ from sagemaker.predictor import Predictor from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import Metadata from sagemaker.serve.spec.inference_spec import InferenceSpec +from sagemaker.serve.utils import task from sagemaker.serve.utils.exceptions import TaskNotFoundException from sagemaker.serve.utils.predictors import _get_local_mode_predictor from sagemaker.serve.detector.image_detector import ( @@ -617,7 +618,7 @@ def build( ) model_task = hf_model_md.get("pipeline_tag") - if self.schema_builder is None: + if self.schema_builder is None and model_task: self._schema_builder_init(model_task) if model_task == "text-generation": # pylint: disable=R1705 diff --git a/src/sagemaker/serve/utils/exceptions.py b/src/sagemaker/serve/utils/exceptions.py index 88a51c63f0..8132820cc0 100644 --- a/src/sagemaker/serve/utils/exceptions.py +++ b/src/sagemaker/serve/utils/exceptions.py @@ -63,7 +63,7 @@ def __init__(self, message): class TaskNotFoundException(ModelBuilderException): - """Raise when task could not be found""" + """Raise when HuggingFace task could not be found""" fmt = "Error Message: {message}" diff --git a/src/sagemaker/task.py b/src/sagemaker/serve/utils/task.py similarity index 80% rename from src/sagemaker/task.py rename to src/sagemaker/serve/utils/task.py index f221c74890..40b713f930 100644 --- a/src/sagemaker/task.py +++ b/src/sagemaker/serve/utils/task.py @@ -31,7 +31,9 @@ def retrieve_local_schemas(task: str) -> Tuple[Any, Any]: Raises: ValueError: If no tasks config found or the task does not exist in the local config. """ - task_io_config_path = os.path.join(os.path.dirname(__file__), "image_uri_config", "tasks.json") + # task_io_config_path = os.path.join(os.path.dirname(__file__), "image_uri_config", "tasks.json") + c = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + task_io_config_path = os.path.join(c, "image_uri_config", "tasks.json") try: with open(task_io_config_path) as f: task_io_config = json.load(f) @@ -41,8 +43,8 @@ def retrieve_local_schemas(task: str) -> Tuple[Any, Any]: raise ValueError(f"Could not find {task} I/O schema.") sample_schema = ( - task_io_schemas["inputs"]["properties"], - task_io_schemas["outputs"]["properties"], + task_io_schemas["sample_inputs"]["properties"], + task_io_schemas["sample_outputs"]["properties"], ) return sample_schema except FileNotFoundError: diff --git a/tests/integ/sagemaker/serve/test_schema_builder.py b/tests/integ/sagemaker/serve/test_schema_builder.py index 679d1a0f47..cd545c62fb 100644 --- a/tests/integ/sagemaker/serve/test_schema_builder.py +++ b/tests/integ/sagemaker/serve/test_schema_builder.py @@ -12,8 +12,19 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import -from sagemaker import task from sagemaker.serve.builder.model_builder import ModelBuilder +from sagemaker.serve.utils import task + +import pytest + +from sagemaker.serve.utils.exceptions import TaskNotFoundException +from tests.integ.sagemaker.serve.constants import ( + PYTHON_VERSION_IS_NOT_310, + SERVE_SAGEMAKER_ENDPOINT_TIMEOUT, +) + +from tests.integ.timeout import timeout +from tests.integ.utils import cleanup_model_resources import logging @@ -33,7 +44,13 @@ def test_model_builder_happy_path_with_only_model_id_fill_mask(sagemaker_session assert model_builder.schema_builder.sample_output == outputs -def test_model_builder_happy_path_with_only_model_id_question_answering(sagemaker_session): +@pytest.mark.skipif( + PYTHON_VERSION_IS_NOT_310, + reason="Testing Schema Builder Simplification feature", +) +def test_model_builder_happy_path_with_only_model_id_question_answering( + sagemaker_session, gpu_instance_type +): model_builder = ModelBuilder(model="bert-large-uncased-whole-word-masking-finetuned-squad") model = model_builder.build(sagemaker_session=sagemaker_session) @@ -44,3 +61,41 @@ def test_model_builder_happy_path_with_only_model_id_question_answering(sagemake inputs, outputs = task.retrieve_local_schemas("question-answering") assert model_builder.schema_builder.sample_input == inputs assert model_builder.schema_builder.sample_output == outputs + + with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT): + caught_ex = None + try: + iam_client = sagemaker_session.boto_session.client("iam") + role_arn = iam_client.get_role(RoleName="JarvisTest")["Role"]["Arn"] + + logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...") + predictor = model.deploy( + role=role_arn, instance_count=1, instance_type=gpu_instance_type + ) + + predicted_outputs = predictor.predict(inputs) + assert predicted_outputs is not None + + except Exception as e: + caught_ex = e + finally: + cleanup_model_resources( + sagemaker_session=model_builder.sagemaker_session, + model_name=model.name, + endpoint_name=model.endpoint_name, + ) + if caught_ex: + logger.exception(caught_ex) + assert ( + False + ), f"{caught_ex} was thrown when running transformers sagemaker endpoint test" + + +def test_model_builder_negative_path(sagemaker_session): + model_builder = ModelBuilder(model="CompVis/stable-diffusion-v1-4") + + with pytest.raises( + TaskNotFoundException, + match="Error Message: Schema builder for text-to-image could not be found.", + ): + model_builder.build(sagemaker_session=sagemaker_session) diff --git a/tests/unit/sagemaker/test_task.py b/tests/unit/sagemaker/serve/utils/test_task.py similarity index 69% rename from tests/unit/sagemaker/test_task.py rename to tests/unit/sagemaker/serve/utils/test_task.py index b83392b61d..78553968e1 100644 --- a/tests/unit/sagemaker/test_task.py +++ b/tests/unit/sagemaker/serve/utils/test_task.py @@ -12,11 +12,15 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +from unittest.mock import patch + import pytest -from sagemaker import task + +from sagemaker.serve.utils import task EXPECTED_INPUTS = {"inputs": "Paris is the of France.", "parameters": {}} EXPECTED_OUTPUTS = [{"sequence": "Paris is the capital of France.", "score": 0.7}] +HF_INVALID_TASK = "not-present-task" def test_retrieve_local_schemas_success(): @@ -34,5 +38,12 @@ def test_retrieve_local_schemas_text_generation_success(): def test_retrieve_local_schemas_throws(): - with pytest.raises(ValueError): - task.retrieve_local_schemas("not-present-task") + with pytest.raises(ValueError, match=f"Could not find {HF_INVALID_TASK} I/O schema."): + task.retrieve_local_schemas(HF_INVALID_TASK) + + +@patch("builtins.open") +def test_retrieve_local_schemas_file_not_found(mock_open): + mock_open.side_effect = FileNotFoundError + with pytest.raises(ValueError, match="Could not find tasks config file."): + task.retrieve_local_schemas(HF_INVALID_TASK) From 99d11cbf51ce1c986c49c1151012ef8f5fdc81dc Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Mon, 19 Feb 2024 14:35:47 -0800 Subject: [PATCH 10/13] Address PR Review Comments --- src/sagemaker/image_uri_config/tasks.json | 2 +- src/sagemaker/serve/utils/task.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/image_uri_config/tasks.json b/src/sagemaker/image_uri_config/tasks.json index b658e2d5bd..9ee6d186a2 100644 --- a/src/sagemaker/image_uri_config/tasks.json +++ b/src/sagemaker/image_uri_config/tasks.json @@ -64,4 +64,4 @@ ] } } -} \ No newline at end of file +} diff --git a/src/sagemaker/serve/utils/task.py b/src/sagemaker/serve/utils/task.py index 40b713f930..7727074882 100644 --- a/src/sagemaker/serve/utils/task.py +++ b/src/sagemaker/serve/utils/task.py @@ -31,9 +31,8 @@ def retrieve_local_schemas(task: str) -> Tuple[Any, Any]: Raises: ValueError: If no tasks config found or the task does not exist in the local config. """ - # task_io_config_path = os.path.join(os.path.dirname(__file__), "image_uri_config", "tasks.json") - c = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) - task_io_config_path = os.path.join(c, "image_uri_config", "tasks.json") + config_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + task_io_config_path = os.path.join(config_dir, "image_uri_config", "tasks.json") try: with open(task_io_config_path) as f: task_io_config = json.load(f) From 38eed2faa00734b49782c7e84014807e478bc84a Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Tue, 20 Feb 2024 11:19:20 -0800 Subject: [PATCH 11/13] Add Unit tests --- .../sagemaker/serve/test_schema_builder.py | 2 +- .../serve/builder/test_model_builder.py | 92 +++++++++++++++++++ 2 files changed, 93 insertions(+), 1 deletion(-) diff --git a/tests/integ/sagemaker/serve/test_schema_builder.py b/tests/integ/sagemaker/serve/test_schema_builder.py index cd545c62fb..3816985d8f 100644 --- a/tests/integ/sagemaker/serve/test_schema_builder.py +++ b/tests/integ/sagemaker/serve/test_schema_builder.py @@ -66,7 +66,7 @@ def test_model_builder_happy_path_with_only_model_id_question_answering( caught_ex = None try: iam_client = sagemaker_session.boto_session.client("iam") - role_arn = iam_client.get_role(RoleName="JarvisTest")["Role"]["Arn"] + role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"] logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...") predictor = model.deploy( diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index 898536c03f..becf63ab41 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -18,6 +18,8 @@ from sagemaker.serve.builder.model_builder import ModelBuilder from sagemaker.serve.mode.function_pointers import Mode +from sagemaker.serve.utils import task +from sagemaker.serve.utils.exceptions import TaskNotFoundException from sagemaker.serve.utils.types import ModelServer from tests.unit.sagemaker.serve.constants import MOCK_IMAGE_CONFIG, MOCK_VPC_CONFIG @@ -985,3 +987,93 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co build_result.deploy(mode=Mode.LOCAL_CONTAINER) self.assertEqual(builder.mode, Mode.LOCAL_CONTAINER) + + @patch("sagemaker.serve.builder.tgi_builder.HuggingFaceModel") + @patch("sagemaker.image_uris.retrieve") + @patch("sagemaker.djl_inference.model.urllib") + @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.huggingface.llm_utils.urllib") + @patch("sagemaker.huggingface.llm_utils.json") + @patch("sagemaker.model_uris.retrieve") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_build_happy_path_when_schema_builder_not_present( + self, + mock_serveSettings, + mock_model_uris_retrieve, + mock_llm_utils_json, + mock_llm_utils_urllib, + mock_model_json, + mock_model_urllib, + mock_image_uris_retrieve, + mock_hf_model, + ): + # Setup mocks + + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + # HF Pipeline Tag + mock_model_uris_retrieve.side_effect = KeyError + mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-generation"} + mock_llm_utils_urllib.request.Request.side_effect = Mock() + + # HF Model config + mock_model_json.load.return_value = {"some": "config"} + mock_model_urllib.request.Request.side_effect = Mock() + + mock_image_uris_retrieve.return_value = "https://some-image-uri" + + model_builder = ModelBuilder(model="meta-llama/Llama-2-7b-hf") + model_builder.build(sagemaker_session=mock_session) + + self.assertIsNotNone(model_builder.schema_builder) + sample_inputs, sample_outputs = task.retrieve_local_schemas("text-generation") + self.assertEqual( + sample_inputs["inputs"], model_builder.schema_builder.sample_input["inputs"] + ) + self.assertEqual(sample_outputs, model_builder.schema_builder.sample_output) + + @patch("sagemaker.serve.builder.tgi_builder.HuggingFaceModel") + @patch("sagemaker.image_uris.retrieve") + @patch("sagemaker.djl_inference.model.urllib") + @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.huggingface.llm_utils.urllib") + @patch("sagemaker.huggingface.llm_utils.json") + @patch("sagemaker.model_uris.retrieve") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_build_negative_path_when_schema_builder_not_present( + self, + mock_serveSettings, + mock_model_uris_retrieve, + mock_llm_utils_json, + mock_llm_utils_urllib, + mock_model_json, + mock_model_urllib, + mock_image_uris_retrieve, + mock_hf_model, + ): + # Setup mocks + + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + # HF Pipeline Tag + mock_model_uris_retrieve.side_effect = KeyError + mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-to-image"} + mock_llm_utils_urllib.request.Request.side_effect = Mock() + + # HF Model config + mock_model_json.load.return_value = {"some": "config"} + mock_model_urllib.request.Request.side_effect = Mock() + + mock_image_uris_retrieve.return_value = "https://some-image-uri" + + model_builder = ModelBuilder(model="CompVis/stable-diffusion-v1-4") + + self.assertRaisesRegexp( + TaskNotFoundException, + "Error Message: Schema builder for text-to-image could not be found.", + lambda: model_builder.build(sagemaker_session=mock_session), + ) From f78637c357f1dafc89c87becddafcbc156391315 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Thu, 22 Feb 2024 09:18:43 -0800 Subject: [PATCH 12/13] Address PR Comments --- MANIFEST.in | 1 + src/sagemaker/serve/schema/task.json | 67 ++++++++++++++++++++++++++++ src/sagemaker/serve/utils/task.py | 4 +- 3 files changed, 70 insertions(+), 2 deletions(-) create mode 100644 src/sagemaker/serve/schema/task.json diff --git a/MANIFEST.in b/MANIFEST.in index f8d3d426f6..c5eeeed043 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,6 +1,7 @@ recursive-include src/sagemaker *.py include src/sagemaker/image_uri_config/*.json +include src/sagemaker/serve/schema/*.json include src/sagemaker/serve/requirements.txt recursive-include requirements * diff --git a/src/sagemaker/serve/schema/task.json b/src/sagemaker/serve/schema/task.json new file mode 100644 index 0000000000..9ee6d186a2 --- /dev/null +++ b/src/sagemaker/serve/schema/task.json @@ -0,0 +1,67 @@ +{ + "fill-mask": { + "sample_inputs": { + "properties": { + "inputs": "Paris is the of France.", + "parameters": {} + } + }, + "sample_outputs": { + "properties": [ + { + "sequence": "Paris is the capital of France.", + "score": 0.7 + } + ] + } + }, + "question-answering": { + "sample_inputs": { + "properties": { + "context": "I have a German Shepherd dog, named Coco.", + "question": "What is my dog's breed?" + } + }, + "sample_outputs": { + "properties": [ + { + "answer": "German Shepherd", + "score": 0.972, + "start": 9, + "end": 24 + } + ] + } + }, + "text-classification": { + "sample_inputs": { + "properties": { + "inputs": "Where is the capital of France?, Paris is the capital of France.", + "parameters": {} + } + }, + "sample_outputs": { + "properties": [ + { + "label": "entailment", + "score": 0.997 + } + ] + } + }, + "text-generation": { + "sample_inputs": { + "properties": { + "inputs": "Hello, I'm a language model", + "parameters": {} + } + }, + "sample_outputs": { + "properties": [ + { + "generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to meet my wife or come home she told me that my" + } + ] + } + } +} diff --git a/src/sagemaker/serve/utils/task.py b/src/sagemaker/serve/utils/task.py index 7727074882..6f8786985c 100644 --- a/src/sagemaker/serve/utils/task.py +++ b/src/sagemaker/serve/utils/task.py @@ -31,8 +31,8 @@ def retrieve_local_schemas(task: str) -> Tuple[Any, Any]: Raises: ValueError: If no tasks config found or the task does not exist in the local config. """ - config_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) - task_io_config_path = os.path.join(config_dir, "image_uri_config", "tasks.json") + config_dir = os.path.dirname(os.path.dirname(__file__)) + task_io_config_path = os.path.join(config_dir, "schema", "task.json") try: with open(task_io_config_path) as f: task_io_config = json.load(f) From 8f6e6b29add3afe8f7a74f91785fc9628b3a1ad2 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Thu, 22 Feb 2024 09:23:07 -0800 Subject: [PATCH 13/13] Address PR Comments --- src/sagemaker/image_uri_config/tasks.json | 67 ----------------------- 1 file changed, 67 deletions(-) delete mode 100644 src/sagemaker/image_uri_config/tasks.json diff --git a/src/sagemaker/image_uri_config/tasks.json b/src/sagemaker/image_uri_config/tasks.json deleted file mode 100644 index 9ee6d186a2..0000000000 --- a/src/sagemaker/image_uri_config/tasks.json +++ /dev/null @@ -1,67 +0,0 @@ -{ - "fill-mask": { - "sample_inputs": { - "properties": { - "inputs": "Paris is the of France.", - "parameters": {} - } - }, - "sample_outputs": { - "properties": [ - { - "sequence": "Paris is the capital of France.", - "score": 0.7 - } - ] - } - }, - "question-answering": { - "sample_inputs": { - "properties": { - "context": "I have a German Shepherd dog, named Coco.", - "question": "What is my dog's breed?" - } - }, - "sample_outputs": { - "properties": [ - { - "answer": "German Shepherd", - "score": 0.972, - "start": 9, - "end": 24 - } - ] - } - }, - "text-classification": { - "sample_inputs": { - "properties": { - "inputs": "Where is the capital of France?, Paris is the capital of France.", - "parameters": {} - } - }, - "sample_outputs": { - "properties": [ - { - "label": "entailment", - "score": 0.997 - } - ] - } - }, - "text-generation": { - "sample_inputs": { - "properties": { - "inputs": "Hello, I'm a language model", - "parameters": {} - } - }, - "sample_outputs": { - "properties": [ - { - "generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to meet my wife or come home she told me that my" - } - ] - } - } -}