Skip to content

ModelBuilder to fetch local schema when no SchemaBuilder present. #4434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Feb 23, 2024
67 changes: 67 additions & 0 deletions src/sagemaker/image_uri_config/tasks.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
{
"fill-mask": {
"sample_inputs": {
"properties": {
"inputs": "Paris is the <mask> of France.",
"parameters": {}
}
},
"sample_outputs": {
"properties": [
{
"sequence": "Paris is the capital of France.",
"score": 0.7
}
]
}
},
"question-answering": {
"sample_inputs": {
"properties": {
"context": "I have a German Shepherd dog, named Coco.",
"question": "What is my dog's breed?"
}
},
"sample_outputs": {
"properties": [
{
"answer": "German Shepherd",
"score": 0.972,
"start": 9,
"end": 24
}
]
}
},
"text-classification": {
"sample_inputs": {
"properties": {
"inputs": "Where is the capital of France?, Paris is the capital of France.",
"parameters": {}
}
},
"sample_outputs": {
"properties": [
{
"label": "entailment",
"score": 0.997
}
]
}
},
"text-generation": {
"sample_inputs": {
"properties": {
"inputs": "Hello, I'm a language model",
"parameters": {}
}
},
"sample_outputs": {
"properties": [
{
"generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to meet my wife or come home she told me that my"
}
]
}
}
}
24 changes: 23 additions & 1 deletion src/sagemaker/serve/builder/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
from sagemaker.predictor import Predictor
from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import Metadata
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.utils import task
from sagemaker.serve.utils.exceptions import TaskNotFoundException
from sagemaker.serve.utils.predictors import _get_local_mode_predictor
from sagemaker.serve.detector.image_detector import (
auto_detect_container,
Expand Down Expand Up @@ -614,7 +616,12 @@ def build(
hf_model_md = get_huggingface_model_metadata(
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
)
if hf_model_md.get("pipeline_tag") == "text-generation": # pylint: disable=R1705

model_task = hf_model_md.get("pipeline_tag")
if self.schema_builder is None and model_task:
self._schema_builder_init(model_task)

if model_task == "text-generation": # pylint: disable=R1705
return self._build_for_tgi()
else:
return self._build_for_transformers()
Expand Down Expand Up @@ -672,3 +679,18 @@ def validate(self, model_dir: str) -> Type[bool]:
"""

return get_metadata(model_dir)

def _schema_builder_init(self, model_task: str):
"""Initialize the schema builder

Args:
model_task (str): Required, the task name

Raises:
TaskNotFoundException: If the I/O schema for the given task is not found.
"""
try:
sample_inputs, sample_outputs = task.retrieve_local_schemas(model_task)
self.schema_builder = SchemaBuilder(sample_inputs, sample_outputs)
except ValueError:
raise TaskNotFoundException(f"Schema builder for {model_task} could not be found.")
9 changes: 9 additions & 0 deletions src/sagemaker/serve/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,12 @@ class SkipTuningComboException(ModelBuilderException):

def __init__(self, message):
super().__init__(message=message)


class TaskNotFoundException(ModelBuilderException):
"""Raise when HuggingFace task could not be found"""

fmt = "Error Message: {message}"

def __init__(self, message):
super().__init__(message=message)
50 changes: 50 additions & 0 deletions src/sagemaker/serve/utils/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Accessors to retrieve task fallback input/output schema"""
from __future__ import absolute_import

import json
import os
from typing import Any, Tuple


def retrieve_local_schemas(task: str) -> Tuple[Any, Any]:
"""Retrieves task sample inputs and outputs locally.

Args:
task (str): Required, the task name

Returns:
Tuple[Any, Any]: A tuple that contains the sample input,
at index 0, and output schema, at index 1.

Raises:
ValueError: If no tasks config found or the task does not exist in the local config.
"""
config_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
task_io_config_path = os.path.join(config_dir, "image_uri_config", "tasks.json")
try:
with open(task_io_config_path) as f:
task_io_config = json.load(f)
task_io_schemas = task_io_config.get(task, None)

if task_io_schemas is None:
raise ValueError(f"Could not find {task} I/O schema.")

sample_schema = (
task_io_schemas["sample_inputs"]["properties"],
task_io_schemas["sample_outputs"]["properties"],
)
return sample_schema
except FileNotFoundError:
raise ValueError("Could not find tasks config file.")
101 changes: 101 additions & 0 deletions tests/integ/sagemaker/serve/test_schema_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.utils import task

import pytest

from sagemaker.serve.utils.exceptions import TaskNotFoundException
from tests.integ.sagemaker.serve.constants import (
PYTHON_VERSION_IS_NOT_310,
SERVE_SAGEMAKER_ENDPOINT_TIMEOUT,
)

from tests.integ.timeout import timeout
from tests.integ.utils import cleanup_model_resources

import logging

logger = logging.getLogger(__name__)


def test_model_builder_happy_path_with_only_model_id_fill_mask(sagemaker_session):
model_builder = ModelBuilder(model="bert-base-uncased")

model = model_builder.build(sagemaker_session=sagemaker_session)

assert model is not None
assert model_builder.schema_builder is not None

inputs, outputs = task.retrieve_local_schemas("fill-mask")
assert model_builder.schema_builder.sample_input == inputs
assert model_builder.schema_builder.sample_output == outputs


@pytest.mark.skipif(
PYTHON_VERSION_IS_NOT_310,
reason="Testing Schema Builder Simplification feature",
)
def test_model_builder_happy_path_with_only_model_id_question_answering(
sagemaker_session, gpu_instance_type
):
model_builder = ModelBuilder(model="bert-large-uncased-whole-word-masking-finetuned-squad")

model = model_builder.build(sagemaker_session=sagemaker_session)

assert model is not None
assert model_builder.schema_builder is not None

inputs, outputs = task.retrieve_local_schemas("question-answering")
assert model_builder.schema_builder.sample_input == inputs
assert model_builder.schema_builder.sample_output == outputs

with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
caught_ex = None
try:
iam_client = sagemaker_session.boto_session.client("iam")
role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]

logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
predictor = model.deploy(
role=role_arn, instance_count=1, instance_type=gpu_instance_type
)

predicted_outputs = predictor.predict(inputs)
assert predicted_outputs is not None

except Exception as e:
caught_ex = e
finally:
cleanup_model_resources(
sagemaker_session=model_builder.sagemaker_session,
model_name=model.name,
endpoint_name=model.endpoint_name,
)
if caught_ex:
logger.exception(caught_ex)
assert (
False
), f"{caught_ex} was thrown when running transformers sagemaker endpoint test"


def test_model_builder_negative_path(sagemaker_session):
model_builder = ModelBuilder(model="CompVis/stable-diffusion-v1-4")

with pytest.raises(
TaskNotFoundException,
match="Error Message: Schema builder for text-to-image could not be found.",
):
model_builder.build(sagemaker_session=sagemaker_session)
92 changes: 92 additions & 0 deletions tests/unit/sagemaker/serve/builder/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.utils import task
from sagemaker.serve.utils.exceptions import TaskNotFoundException
from sagemaker.serve.utils.types import ModelServer
from tests.unit.sagemaker.serve.constants import MOCK_IMAGE_CONFIG, MOCK_VPC_CONFIG

Expand Down Expand Up @@ -985,3 +987,93 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co
build_result.deploy(mode=Mode.LOCAL_CONTAINER)

self.assertEqual(builder.mode, Mode.LOCAL_CONTAINER)

@patch("sagemaker.serve.builder.tgi_builder.HuggingFaceModel")
@patch("sagemaker.image_uris.retrieve")
@patch("sagemaker.djl_inference.model.urllib")
@patch("sagemaker.djl_inference.model.json")
@patch("sagemaker.huggingface.llm_utils.urllib")
@patch("sagemaker.huggingface.llm_utils.json")
@patch("sagemaker.model_uris.retrieve")
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
def test_build_happy_path_when_schema_builder_not_present(
self,
mock_serveSettings,
mock_model_uris_retrieve,
mock_llm_utils_json,
mock_llm_utils_urllib,
mock_model_json,
mock_model_urllib,
mock_image_uris_retrieve,
mock_hf_model,
):
# Setup mocks

mock_setting_object = mock_serveSettings.return_value
mock_setting_object.role_arn = mock_role_arn
mock_setting_object.s3_model_data_url = mock_s3_model_data_url

# HF Pipeline Tag
mock_model_uris_retrieve.side_effect = KeyError
mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-generation"}
mock_llm_utils_urllib.request.Request.side_effect = Mock()

# HF Model config
mock_model_json.load.return_value = {"some": "config"}
mock_model_urllib.request.Request.side_effect = Mock()

mock_image_uris_retrieve.return_value = "https://some-image-uri"

model_builder = ModelBuilder(model="meta-llama/Llama-2-7b-hf")
model_builder.build(sagemaker_session=mock_session)

self.assertIsNotNone(model_builder.schema_builder)
sample_inputs, sample_outputs = task.retrieve_local_schemas("text-generation")
self.assertEqual(
sample_inputs["inputs"], model_builder.schema_builder.sample_input["inputs"]
)
self.assertEqual(sample_outputs, model_builder.schema_builder.sample_output)

@patch("sagemaker.serve.builder.tgi_builder.HuggingFaceModel")
@patch("sagemaker.image_uris.retrieve")
@patch("sagemaker.djl_inference.model.urllib")
@patch("sagemaker.djl_inference.model.json")
@patch("sagemaker.huggingface.llm_utils.urllib")
@patch("sagemaker.huggingface.llm_utils.json")
@patch("sagemaker.model_uris.retrieve")
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
def test_build_negative_path_when_schema_builder_not_present(
self,
mock_serveSettings,
mock_model_uris_retrieve,
mock_llm_utils_json,
mock_llm_utils_urllib,
mock_model_json,
mock_model_urllib,
mock_image_uris_retrieve,
mock_hf_model,
):
# Setup mocks

mock_setting_object = mock_serveSettings.return_value
mock_setting_object.role_arn = mock_role_arn
mock_setting_object.s3_model_data_url = mock_s3_model_data_url

# HF Pipeline Tag
mock_model_uris_retrieve.side_effect = KeyError
mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-to-image"}
mock_llm_utils_urllib.request.Request.side_effect = Mock()

# HF Model config
mock_model_json.load.return_value = {"some": "config"}
mock_model_urllib.request.Request.side_effect = Mock()

mock_image_uris_retrieve.return_value = "https://some-image-uri"

model_builder = ModelBuilder(model="CompVis/stable-diffusion-v1-4")

self.assertRaisesRegexp(
TaskNotFoundException,
"Error Message: Schema builder for text-to-image could not be found.",
lambda: model_builder.build(sagemaker_session=mock_session),
)
Loading