Skip to content

Commit 2a709b5

Browse files
makungaj1Jonathan Makunga
authored andcommitted
feat: ModelBuilder to fetch local schema when no SchemaBuilder present. (aws#4434)
* Fetch Schema locally * Fetch Schema locally * Local schema * Test local schemas * Testing * Testing Schema * Schema for DJL * Add Integ tests * address PR comments * Address PR Review Comments * Add Unit tests * Address PR Comments * Address PR Comments --------- Co-authored-by: Jonathan Makunga <[email protected]>
1 parent 3be85ce commit 2a709b5

File tree

8 files changed

+392
-1
lines changed

8 files changed

+392
-1
lines changed

MANIFEST.in

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
recursive-include src/sagemaker *.py
22

33
include src/sagemaker/image_uri_config/*.json
4+
include src/sagemaker/serve/schema/*.json
45
include src/sagemaker/serve/requirements.txt
56
recursive-include requirements *
67

src/sagemaker/serve/builder/model_builder.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
from sagemaker.predictor import Predictor
3939
from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import Metadata
4040
from sagemaker.serve.spec.inference_spec import InferenceSpec
41+
from sagemaker.serve.utils import task
42+
from sagemaker.serve.utils.exceptions import TaskNotFoundException
4143
from sagemaker.serve.utils.predictors import _get_local_mode_predictor
4244
from sagemaker.serve.detector.image_detector import (
4345
auto_detect_container,
@@ -616,7 +618,12 @@ def build(
616618
hf_model_md = get_huggingface_model_metadata(
617619
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
618620
)
619-
if hf_model_md.get("pipeline_tag") == "text-generation": # pylint: disable=R1705
621+
622+
model_task = hf_model_md.get("pipeline_tag")
623+
if self.schema_builder is None and model_task:
624+
self._schema_builder_init(model_task)
625+
626+
if model_task == "text-generation": # pylint: disable=R1705
620627
return self._build_for_tgi()
621628
else:
622629
return self._build_for_transformers()
@@ -674,3 +681,18 @@ def validate(self, model_dir: str) -> Type[bool]:
674681
"""
675682

676683
return get_metadata(model_dir)
684+
685+
def _schema_builder_init(self, model_task: str):
686+
"""Initialize the schema builder
687+
688+
Args:
689+
model_task (str): Required, the task name
690+
691+
Raises:
692+
TaskNotFoundException: If the I/O schema for the given task is not found.
693+
"""
694+
try:
695+
sample_inputs, sample_outputs = task.retrieve_local_schemas(model_task)
696+
self.schema_builder = SchemaBuilder(sample_inputs, sample_outputs)
697+
except ValueError:
698+
raise TaskNotFoundException(f"Schema builder for {model_task} could not be found.")

src/sagemaker/serve/schema/task.json

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
{
2+
"fill-mask": {
3+
"sample_inputs": {
4+
"properties": {
5+
"inputs": "Paris is the <mask> of France.",
6+
"parameters": {}
7+
}
8+
},
9+
"sample_outputs": {
10+
"properties": [
11+
{
12+
"sequence": "Paris is the capital of France.",
13+
"score": 0.7
14+
}
15+
]
16+
}
17+
},
18+
"question-answering": {
19+
"sample_inputs": {
20+
"properties": {
21+
"context": "I have a German Shepherd dog, named Coco.",
22+
"question": "What is my dog's breed?"
23+
}
24+
},
25+
"sample_outputs": {
26+
"properties": [
27+
{
28+
"answer": "German Shepherd",
29+
"score": 0.972,
30+
"start": 9,
31+
"end": 24
32+
}
33+
]
34+
}
35+
},
36+
"text-classification": {
37+
"sample_inputs": {
38+
"properties": {
39+
"inputs": "Where is the capital of France?, Paris is the capital of France.",
40+
"parameters": {}
41+
}
42+
},
43+
"sample_outputs": {
44+
"properties": [
45+
{
46+
"label": "entailment",
47+
"score": 0.997
48+
}
49+
]
50+
}
51+
},
52+
"text-generation": {
53+
"sample_inputs": {
54+
"properties": {
55+
"inputs": "Hello, I'm a language model",
56+
"parameters": {}
57+
}
58+
},
59+
"sample_outputs": {
60+
"properties": [
61+
{
62+
"generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to meet my wife or come home she told me that my"
63+
}
64+
]
65+
}
66+
}
67+
}

src/sagemaker/serve/utils/exceptions.py

+9
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,12 @@ class SkipTuningComboException(ModelBuilderException):
6060

6161
def __init__(self, message):
6262
super().__init__(message=message)
63+
64+
65+
class TaskNotFoundException(ModelBuilderException):
66+
"""Raise when HuggingFace task could not be found"""
67+
68+
fmt = "Error Message: {message}"
69+
70+
def __init__(self, message):
71+
super().__init__(message=message)

src/sagemaker/serve/utils/task.py

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Accessors to retrieve task fallback input/output schema"""
14+
from __future__ import absolute_import
15+
16+
import json
17+
import os
18+
from typing import Any, Tuple
19+
20+
21+
def retrieve_local_schemas(task: str) -> Tuple[Any, Any]:
22+
"""Retrieves task sample inputs and outputs locally.
23+
24+
Args:
25+
task (str): Required, the task name
26+
27+
Returns:
28+
Tuple[Any, Any]: A tuple that contains the sample input,
29+
at index 0, and output schema, at index 1.
30+
31+
Raises:
32+
ValueError: If no tasks config found or the task does not exist in the local config.
33+
"""
34+
config_dir = os.path.dirname(os.path.dirname(__file__))
35+
task_io_config_path = os.path.join(config_dir, "schema", "task.json")
36+
try:
37+
with open(task_io_config_path) as f:
38+
task_io_config = json.load(f)
39+
task_io_schemas = task_io_config.get(task, None)
40+
41+
if task_io_schemas is None:
42+
raise ValueError(f"Could not find {task} I/O schema.")
43+
44+
sample_schema = (
45+
task_io_schemas["sample_inputs"]["properties"],
46+
task_io_schemas["sample_outputs"]["properties"],
47+
)
48+
return sample_schema
49+
except FileNotFoundError:
50+
raise ValueError("Could not find tasks config file.")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from sagemaker.serve.builder.model_builder import ModelBuilder
16+
from sagemaker.serve.utils import task
17+
18+
import pytest
19+
20+
from sagemaker.serve.utils.exceptions import TaskNotFoundException
21+
from tests.integ.sagemaker.serve.constants import (
22+
PYTHON_VERSION_IS_NOT_310,
23+
SERVE_SAGEMAKER_ENDPOINT_TIMEOUT,
24+
)
25+
26+
from tests.integ.timeout import timeout
27+
from tests.integ.utils import cleanup_model_resources
28+
29+
import logging
30+
31+
logger = logging.getLogger(__name__)
32+
33+
34+
def test_model_builder_happy_path_with_only_model_id_fill_mask(sagemaker_session):
35+
model_builder = ModelBuilder(model="bert-base-uncased")
36+
37+
model = model_builder.build(sagemaker_session=sagemaker_session)
38+
39+
assert model is not None
40+
assert model_builder.schema_builder is not None
41+
42+
inputs, outputs = task.retrieve_local_schemas("fill-mask")
43+
assert model_builder.schema_builder.sample_input == inputs
44+
assert model_builder.schema_builder.sample_output == outputs
45+
46+
47+
@pytest.mark.skipif(
48+
PYTHON_VERSION_IS_NOT_310,
49+
reason="Testing Schema Builder Simplification feature",
50+
)
51+
def test_model_builder_happy_path_with_only_model_id_question_answering(
52+
sagemaker_session, gpu_instance_type
53+
):
54+
model_builder = ModelBuilder(model="bert-large-uncased-whole-word-masking-finetuned-squad")
55+
56+
model = model_builder.build(sagemaker_session=sagemaker_session)
57+
58+
assert model is not None
59+
assert model_builder.schema_builder is not None
60+
61+
inputs, outputs = task.retrieve_local_schemas("question-answering")
62+
assert model_builder.schema_builder.sample_input == inputs
63+
assert model_builder.schema_builder.sample_output == outputs
64+
65+
with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
66+
caught_ex = None
67+
try:
68+
iam_client = sagemaker_session.boto_session.client("iam")
69+
role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]
70+
71+
logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
72+
predictor = model.deploy(
73+
role=role_arn, instance_count=1, instance_type=gpu_instance_type
74+
)
75+
76+
predicted_outputs = predictor.predict(inputs)
77+
assert predicted_outputs is not None
78+
79+
except Exception as e:
80+
caught_ex = e
81+
finally:
82+
cleanup_model_resources(
83+
sagemaker_session=model_builder.sagemaker_session,
84+
model_name=model.name,
85+
endpoint_name=model.endpoint_name,
86+
)
87+
if caught_ex:
88+
logger.exception(caught_ex)
89+
assert (
90+
False
91+
), f"{caught_ex} was thrown when running transformers sagemaker endpoint test"
92+
93+
94+
def test_model_builder_negative_path(sagemaker_session):
95+
model_builder = ModelBuilder(model="CompVis/stable-diffusion-v1-4")
96+
97+
with pytest.raises(
98+
TaskNotFoundException,
99+
match="Error Message: Schema builder for text-to-image could not be found.",
100+
):
101+
model_builder.build(sagemaker_session=sagemaker_session)

tests/unit/sagemaker/serve/builder/test_model_builder.py

+92
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
from sagemaker.serve.builder.model_builder import ModelBuilder
2020
from sagemaker.serve.mode.function_pointers import Mode
21+
from sagemaker.serve.utils import task
22+
from sagemaker.serve.utils.exceptions import TaskNotFoundException
2123
from sagemaker.serve.utils.types import ModelServer
2224
from tests.unit.sagemaker.serve.constants import MOCK_IMAGE_CONFIG, MOCK_VPC_CONFIG
2325

@@ -985,3 +987,93 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co
985987
build_result.deploy(mode=Mode.LOCAL_CONTAINER)
986988

987989
self.assertEqual(builder.mode, Mode.LOCAL_CONTAINER)
990+
991+
@patch("sagemaker.serve.builder.tgi_builder.HuggingFaceModel")
992+
@patch("sagemaker.image_uris.retrieve")
993+
@patch("sagemaker.djl_inference.model.urllib")
994+
@patch("sagemaker.djl_inference.model.json")
995+
@patch("sagemaker.huggingface.llm_utils.urllib")
996+
@patch("sagemaker.huggingface.llm_utils.json")
997+
@patch("sagemaker.model_uris.retrieve")
998+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
999+
def test_build_happy_path_when_schema_builder_not_present(
1000+
self,
1001+
mock_serveSettings,
1002+
mock_model_uris_retrieve,
1003+
mock_llm_utils_json,
1004+
mock_llm_utils_urllib,
1005+
mock_model_json,
1006+
mock_model_urllib,
1007+
mock_image_uris_retrieve,
1008+
mock_hf_model,
1009+
):
1010+
# Setup mocks
1011+
1012+
mock_setting_object = mock_serveSettings.return_value
1013+
mock_setting_object.role_arn = mock_role_arn
1014+
mock_setting_object.s3_model_data_url = mock_s3_model_data_url
1015+
1016+
# HF Pipeline Tag
1017+
mock_model_uris_retrieve.side_effect = KeyError
1018+
mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-generation"}
1019+
mock_llm_utils_urllib.request.Request.side_effect = Mock()
1020+
1021+
# HF Model config
1022+
mock_model_json.load.return_value = {"some": "config"}
1023+
mock_model_urllib.request.Request.side_effect = Mock()
1024+
1025+
mock_image_uris_retrieve.return_value = "https://some-image-uri"
1026+
1027+
model_builder = ModelBuilder(model="meta-llama/Llama-2-7b-hf")
1028+
model_builder.build(sagemaker_session=mock_session)
1029+
1030+
self.assertIsNotNone(model_builder.schema_builder)
1031+
sample_inputs, sample_outputs = task.retrieve_local_schemas("text-generation")
1032+
self.assertEqual(
1033+
sample_inputs["inputs"], model_builder.schema_builder.sample_input["inputs"]
1034+
)
1035+
self.assertEqual(sample_outputs, model_builder.schema_builder.sample_output)
1036+
1037+
@patch("sagemaker.serve.builder.tgi_builder.HuggingFaceModel")
1038+
@patch("sagemaker.image_uris.retrieve")
1039+
@patch("sagemaker.djl_inference.model.urllib")
1040+
@patch("sagemaker.djl_inference.model.json")
1041+
@patch("sagemaker.huggingface.llm_utils.urllib")
1042+
@patch("sagemaker.huggingface.llm_utils.json")
1043+
@patch("sagemaker.model_uris.retrieve")
1044+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
1045+
def test_build_negative_path_when_schema_builder_not_present(
1046+
self,
1047+
mock_serveSettings,
1048+
mock_model_uris_retrieve,
1049+
mock_llm_utils_json,
1050+
mock_llm_utils_urllib,
1051+
mock_model_json,
1052+
mock_model_urllib,
1053+
mock_image_uris_retrieve,
1054+
mock_hf_model,
1055+
):
1056+
# Setup mocks
1057+
1058+
mock_setting_object = mock_serveSettings.return_value
1059+
mock_setting_object.role_arn = mock_role_arn
1060+
mock_setting_object.s3_model_data_url = mock_s3_model_data_url
1061+
1062+
# HF Pipeline Tag
1063+
mock_model_uris_retrieve.side_effect = KeyError
1064+
mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-to-image"}
1065+
mock_llm_utils_urllib.request.Request.side_effect = Mock()
1066+
1067+
# HF Model config
1068+
mock_model_json.load.return_value = {"some": "config"}
1069+
mock_model_urllib.request.Request.side_effect = Mock()
1070+
1071+
mock_image_uris_retrieve.return_value = "https://some-image-uri"
1072+
1073+
model_builder = ModelBuilder(model="CompVis/stable-diffusion-v1-4")
1074+
1075+
self.assertRaisesRegexp(
1076+
TaskNotFoundException,
1077+
"Error Message: Schema builder for text-to-image could not be found.",
1078+
lambda: model_builder.build(sagemaker_session=mock_session),
1079+
)

0 commit comments

Comments
 (0)