Skip to content

ModelBuilder to fetch local schema when no SchemaBuilder present. #4434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Feb 23, 2024
80 changes: 80 additions & 0 deletions src/sagemaker/image_uri_config/tasks.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
{
"description": "Sample Task Inputs and Outputs",
"fill-mask": {
"ref": "https://huggingface.co/tasks/fill-mask",
"inputs": {
"ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/fill-mask/spec/input.json",
"properties": {
"inputs": "Paris is the <mask> of France.",
"parameters": {}
}
},
"outputs": {
"ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/fill-mask/spec/output.json",
"properties": [
{
"sequence": "Paris is the capital of France.",
"score": 0.7
}
]
}
},
"question-answering": {
"ref": "https://huggingface.co/tasks/question-answering",
"inputs": {
"ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/question-answering/spec/input.json",
"properties": {
"context": "I have a German Shepherd dog, named Coco.",
"question": "What is my dog's breed?"
}
},
"outputs": {
"ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/question-answering/spec/output.json",
"properties": [
{
"answer": "German Shepherd",
"score": 0.972,
"start": 9,
"end": 24
}
]
}
},
"text-classification": {
"ref": "https://huggingface.co/tasks/text-classification",
"inputs": {
"ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/text-classification/spec/input.json",
"properties": {
"inputs": "Where is the capital of France?, Paris is the capital of France.",
"parameters": {}
}
},
"outputs": {
"ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/text-classification/spec/output.json",
"properties": [
{
"label": "entailment",
"score": 0.997
}
]
}
},
"text-generation": {
"ref": "https://huggingface.co/tasks/text-generation",
"inputs": {
"ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/text-generation/spec/input.json",
"properties": {
"inputs": "Hello, I'm a language model",
"parameters": {}
}
},
"outputs": {
"ref": "https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/text-generation/spec/output.json",
"properties": [
{
"generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to meet my wife or come home she told me that my"
}
]
}
}
}
25 changes: 23 additions & 2 deletions src/sagemaker/serve/builder/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from pathlib import Path

from sagemaker import Session
from sagemaker import Session, task
from sagemaker.model import Model
from sagemaker.base_predictor import PredictorBase
from sagemaker.serializers import NumpySerializer, TorchTensorSerializer
Expand All @@ -38,6 +38,7 @@
from sagemaker.predictor import Predictor
from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import Metadata
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.utils.exceptions import TaskNotFoundException
from sagemaker.serve.utils.predictors import _get_local_mode_predictor
from sagemaker.serve.detector.image_detector import (
auto_detect_container,
Expand Down Expand Up @@ -614,7 +615,12 @@ def build(
hf_model_md = get_huggingface_model_metadata(
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
)
if hf_model_md.get("pipeline_tag") == "text-generation": # pylint: disable=R1705

model_task = hf_model_md.get("pipeline_tag")
if self.schema_builder is None:
self._schema_builder_init(model_task)

if model_task == "text-generation": # pylint: disable=R1705
return self._build_for_tgi()
else:
return self._build_for_transformers()
Expand Down Expand Up @@ -672,3 +678,18 @@ def validate(self, model_dir: str) -> Type[bool]:
"""

return get_metadata(model_dir)

def _schema_builder_init(self, model_task: str):
"""Initialize the schema builder

Args:
model_task (str): Required, the task name

Raises:
TaskNotFoundException: If the I/O schema for the given task is not found.
"""
try:
sample_inputs, sample_outputs = task.retrieve_local_schemas(model_task)
self.schema_builder = SchemaBuilder(sample_inputs, sample_outputs)
except ValueError:
raise TaskNotFoundException(f"Schema builder for {model_task} could not be found.")
9 changes: 9 additions & 0 deletions src/sagemaker/serve/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,12 @@ class SkipTuningComboException(ModelBuilderException):

def __init__(self, message):
super().__init__(message=message)


class TaskNotFoundException(ModelBuilderException):
"""Raise when task could not be found"""

fmt = "Error Message: {message}"

def __init__(self, message):
super().__init__(message=message)
49 changes: 49 additions & 0 deletions src/sagemaker/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Accessors to retrieve task fallback input/output schema"""
from __future__ import absolute_import

import json
import os
from typing import Any, Tuple


def retrieve_local_schemas(task: str) -> Tuple[Any, Any]:
"""Retrieves task sample inputs and outputs locally.

Args:
task (str): Required, the task name

Returns:
Tuple[Any, Any]: A tuple that contains the sample input,
at index 0, and output schema, at index 1.

Raises:
ValueError: If no tasks config found or the task does not exist in the local config.
"""
task_io_config_path = os.path.join(os.path.dirname(__file__), "image_uri_config", "tasks.json")
try:
with open(task_io_config_path) as f:
task_io_config = json.load(f)
task_io_schemas = task_io_config.get(task, None)

if task_io_schemas is None:
raise ValueError(f"Could not find {task} I/O schema.")

sample_schema = (
task_io_schemas["inputs"]["properties"],
task_io_schemas["outputs"]["properties"],
)
return sample_schema
except FileNotFoundError:
raise ValueError("Could not find tasks config file.")
46 changes: 46 additions & 0 deletions tests/integ/sagemaker/serve/test_schema_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

from sagemaker import task
from sagemaker.serve.builder.model_builder import ModelBuilder

import logging

logger = logging.getLogger(__name__)


def test_model_builder_happy_path_with_only_model_id_fill_mask(sagemaker_session):
model_builder = ModelBuilder(model="bert-base-uncased")

model = model_builder.build(sagemaker_session=sagemaker_session)

assert model is not None
assert model_builder.schema_builder is not None

inputs, outputs = task.retrieve_local_schemas("fill-mask")
assert model_builder.schema_builder.sample_input == inputs
assert model_builder.schema_builder.sample_output == outputs


def test_model_builder_happy_path_with_only_model_id_question_answering(sagemaker_session):
model_builder = ModelBuilder(model="bert-large-uncased-whole-word-masking-finetuned-squad")

model = model_builder.build(sagemaker_session=sagemaker_session)

assert model is not None
assert model_builder.schema_builder is not None

inputs, outputs = task.retrieve_local_schemas("question-answering")
assert model_builder.schema_builder.sample_input == inputs
assert model_builder.schema_builder.sample_output == outputs
38 changes: 38 additions & 0 deletions tests/unit/sagemaker/test_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import pytest
from sagemaker import task

EXPECTED_INPUTS = {"inputs": "Paris is the <mask> of France.", "parameters": {}}
EXPECTED_OUTPUTS = [{"sequence": "Paris is the capital of France.", "score": 0.7}]


def test_retrieve_local_schemas_success():
inputs, outputs = task.retrieve_local_schemas("fill-mask")

assert inputs == EXPECTED_INPUTS
assert outputs == EXPECTED_OUTPUTS


def test_retrieve_local_schemas_text_generation_success():
inputs, outputs = task.retrieve_local_schemas("text-generation")

assert inputs is not None
assert outputs is not None


def test_retrieve_local_schemas_throws():
with pytest.raises(ValueError):
task.retrieve_local_schemas("not-present-task")