Skip to content

Commit a72dc6d

Browse files
feature: Changes to support remote schema retrieval for task types (question-answering, fill-mask) and added e2e tests for both local and remote hf schema logic. (#4572)
* feature: Add and use sagemaker_schema_inference_artifacts dependency for huggingface in schema builder (question-answering only) * feature: Switch to remote schema for hf tasks question-answering and fill-mask with appropriate e2e integ tests. * Format fixes * Fix pylint * Format fixes * Remove speech recognition serializer fixes * Test fixes --------- Co-authored-by: Shailav Taneja <none> Co-authored-by: Shailav Taneja <[email protected]>
1 parent a211c76 commit a72dc6d

File tree

6 files changed

+128
-86
lines changed

6 files changed

+128
-86
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
accelerate>=0.24.1,<=0.27.0
2+
sagemaker_schema_inference_artifacts>=0.0.5

src/sagemaker/serve/builder/model_builder.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,7 @@ def build( # pylint: disable=R0911
637637
if model_task is None:
638638
model_task = hf_model_md.get("pipeline_tag")
639639
if self.schema_builder is None and model_task is not None:
640-
self._schema_builder_init(model_task)
640+
self._hf_schema_builder_init(model_task)
641641
if model_task == "text-generation": # pylint: disable=R1705
642642
return self._build_for_tgi()
643643
elif self._can_fit_on_single_gpu():
@@ -704,8 +704,8 @@ def validate(self, model_dir: str) -> Type[bool]:
704704

705705
return get_metadata(model_dir)
706706

707-
def _schema_builder_init(self, model_task: str):
708-
"""Initialize the schema builder
707+
def _hf_schema_builder_init(self, model_task: str):
708+
"""Initialize the schema builder for the given HF_TASK
709709
710710
Args:
711711
model_task (str): Required, the task name
@@ -714,10 +714,29 @@ def _schema_builder_init(self, model_task: str):
714714
TaskNotFoundException: If the I/O schema for the given task is not found.
715715
"""
716716
try:
717-
sample_inputs, sample_outputs = task.retrieve_local_schemas(model_task)
717+
try:
718+
sample_inputs, sample_outputs = task.retrieve_local_schemas(model_task)
719+
except ValueError:
720+
# samples could not be loaded locally, try to fetch remote hf schema
721+
from sagemaker_schema_inference_artifacts.huggingface import remote_schema_retriever
722+
723+
if model_task in ("text-to-image", "automatic-speech-recognition"):
724+
logger.warning(
725+
"HF SchemaBuilder for %s is in beta mode, and is not guaranteed to work "
726+
"with all models at this time.",
727+
model_task,
728+
)
729+
remote_hf_schema_helper = remote_schema_retriever.RemoteSchemaRetriever()
730+
(
731+
sample_inputs,
732+
sample_outputs,
733+
) = remote_hf_schema_helper.get_resolved_hf_schema_for_task(model_task)
718734
self.schema_builder = SchemaBuilder(sample_inputs, sample_outputs)
719735
except ValueError:
720-
raise TaskNotFoundException(f"Schema builder for {model_task} could not be found.")
736+
raise TaskNotFoundException(
737+
f"HuggingFace Schema builder samples for {model_task} could not be found "
738+
f"locally or via remote."
739+
)
721740

722741
def _can_fit_on_single_gpu(self) -> Type[bool]:
723742
"""Check if model can fit on a single GPU

src/sagemaker/serve/schema/task.json

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,4 @@
11
{
2-
"fill-mask": {
3-
"sample_inputs": {
4-
"properties": {
5-
"inputs": "Paris is the [MASK] of France.",
6-
"parameters": {}
7-
}
8-
},
9-
"sample_outputs": {
10-
"properties": [
11-
{
12-
"sequence": "Paris is the capital of France.",
13-
"score": 0.7
14-
}
15-
]
16-
}
17-
},
18-
"question-answering": {
19-
"sample_inputs": {
20-
"properties": {
21-
"context": "I have a German Shepherd dog, named Coco.",
22-
"question": "What is my dog's breed?"
23-
}
24-
},
25-
"sample_outputs": {
26-
"properties": [
27-
{
28-
"answer": "German Shepherd",
29-
"score": 0.972,
30-
"start": 9,
31-
"end": 24
32-
}
33-
]
34-
}
35-
},
362
"text-classification": {
373
"sample_inputs": {
384
"properties": {

tests/integ/sagemaker/serve/test_schema_builder.py

Lines changed: 82 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import pytest
1919

2020
from sagemaker.serve.utils.exceptions import TaskNotFoundException
21+
from sagemaker_schema_inference_artifacts.huggingface import remote_schema_retriever
2122
from tests.integ.sagemaker.serve.constants import (
2223
PYTHON_VERSION_IS_NOT_310,
2324
SERVE_SAGEMAKER_ENDPOINT_TIMEOUT,
@@ -31,35 +32,73 @@
3132
logger = logging.getLogger(__name__)
3233

3334

34-
def test_model_builder_happy_path_with_only_model_id_fill_mask(sagemaker_session):
35-
model_builder = ModelBuilder(model="bert-base-uncased")
35+
def test_model_builder_happy_path_with_only_model_id_text_generation(sagemaker_session):
36+
model_builder = ModelBuilder(model="HuggingFaceH4/zephyr-7b-beta")
3637

3738
model = model_builder.build(sagemaker_session=sagemaker_session)
3839

3940
assert model is not None
4041
assert model_builder.schema_builder is not None
4142

42-
inputs, outputs = task.retrieve_local_schemas("fill-mask")
43-
assert model_builder.schema_builder.sample_input == inputs
43+
inputs, outputs = task.retrieve_local_schemas("text-generation")
44+
assert model_builder.schema_builder.sample_input["inputs"] == inputs["inputs"]
4445
assert model_builder.schema_builder.sample_output == outputs
4546

4647

48+
def test_model_builder_negative_path(sagemaker_session):
49+
# A model-task combo unsupported by both the local and remote schema fallback options. (eg: text-to-video)
50+
model_builder = ModelBuilder(model="ByteDance/AnimateDiff-Lightning")
51+
with pytest.raises(
52+
TaskNotFoundException,
53+
match="Error Message: HuggingFace Schema builder samples for text-to-video could not be found locally or "
54+
"via remote.",
55+
):
56+
model_builder.build(sagemaker_session=sagemaker_session)
57+
58+
4759
@pytest.mark.skipif(
4860
PYTHON_VERSION_IS_NOT_310,
49-
reason="Testing Schema Builder Simplification feature",
61+
reason="Testing Schema Builder Simplification feature - Local Schema",
5062
)
51-
def test_model_builder_happy_path_with_only_model_id_question_answering(
52-
sagemaker_session, gpu_instance_type
63+
@pytest.mark.parametrize(
64+
"model_id, task_provided, instance_type_provided, container_startup_timeout",
65+
[
66+
(
67+
"distilbert/distilbert-base-uncased-finetuned-sst-2-english",
68+
"text-classification",
69+
"ml.m5.xlarge",
70+
None,
71+
),
72+
(
73+
"cardiffnlp/twitter-roberta-base-sentiment-latest",
74+
"text-classification",
75+
"ml.m5.xlarge",
76+
None,
77+
),
78+
("HuggingFaceH4/zephyr-7b-beta", "text-generation", "ml.g5.2xlarge", 900),
79+
("HuggingFaceH4/zephyr-7b-alpha", "text-generation", "ml.g5.2xlarge", 900),
80+
],
81+
)
82+
def test_model_builder_happy_path_with_task_provided_local_schema_mode(
83+
model_id, task_provided, sagemaker_session, instance_type_provided, container_startup_timeout
5384
):
54-
model_builder = ModelBuilder(model="bert-large-uncased-whole-word-masking-finetuned-squad")
85+
model_builder = ModelBuilder(
86+
model=model_id,
87+
model_metadata={"HF_TASK": task_provided},
88+
instance_type=instance_type_provided,
89+
)
5590

5691
model = model_builder.build(sagemaker_session=sagemaker_session)
5792

5893
assert model is not None
5994
assert model_builder.schema_builder is not None
6095

61-
inputs, outputs = task.retrieve_local_schemas("question-answering")
62-
assert model_builder.schema_builder.sample_input == inputs
96+
inputs, outputs = task.retrieve_local_schemas(task_provided)
97+
if task_provided == "text-generation":
98+
# ignore 'tokens' and other metadata in this case
99+
assert model_builder.schema_builder.sample_input["inputs"] == inputs["inputs"]
100+
else:
101+
assert model_builder.schema_builder.sample_input == inputs
63102
assert model_builder.schema_builder.sample_output == outputs
64103

65104
with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
@@ -69,9 +108,17 @@ def test_model_builder_happy_path_with_only_model_id_question_answering(
69108
role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]
70109

71110
logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
72-
predictor = model.deploy(
73-
role=role_arn, instance_count=1, instance_type=gpu_instance_type
74-
)
111+
if container_startup_timeout:
112+
predictor = model.deploy(
113+
role=role_arn,
114+
instance_count=1,
115+
instance_type=instance_type_provided,
116+
container_startup_health_check_timeout=container_startup_timeout,
117+
)
118+
else:
119+
predictor = model.deploy(
120+
role=role_arn, instance_count=1, instance_type=instance_type_provided
121+
)
75122

76123
predicted_outputs = predictor.predict(inputs)
77124
assert predicted_outputs is not None
@@ -91,38 +138,38 @@ def test_model_builder_happy_path_with_only_model_id_question_answering(
91138
), f"{caught_ex} was thrown when running transformers sagemaker endpoint test"
92139

93140

94-
def test_model_builder_negative_path(sagemaker_session):
95-
model_builder = ModelBuilder(model="CompVis/stable-diffusion-v1-4")
96-
97-
with pytest.raises(
98-
TaskNotFoundException,
99-
match="Error Message: Schema builder for text-to-image could not be found.",
100-
):
101-
model_builder.build(sagemaker_session=sagemaker_session)
102-
103-
104141
@pytest.mark.skipif(
105142
PYTHON_VERSION_IS_NOT_310,
106-
reason="Testing Schema Builder Simplification feature",
143+
reason="Testing Schema Builder Simplification feature - Remote Schema",
107144
)
108145
@pytest.mark.parametrize(
109-
"model_id, task_provided",
146+
"model_id, task_provided, instance_type_provided",
110147
[
111-
("bert-base-uncased", "fill-mask"),
112-
("bert-large-uncased-whole-word-masking-finetuned-squad", "question-answering"),
148+
("google-bert/bert-base-uncased", "fill-mask", "ml.m5.xlarge"),
149+
("google-bert/bert-base-cased", "fill-mask", "ml.m5.xlarge"),
150+
(
151+
"google-bert/bert-large-uncased-whole-word-masking-finetuned-squad",
152+
"question-answering",
153+
"ml.m5.xlarge",
154+
),
155+
("deepset/roberta-base-squad2", "question-answering", "ml.m5.xlarge"),
113156
],
114157
)
115-
def test_model_builder_happy_path_with_task_provided(
116-
model_id, task_provided, sagemaker_session, gpu_instance_type
158+
def test_model_builder_happy_path_with_task_provided_remote_schema_mode(
159+
model_id, task_provided, sagemaker_session, instance_type_provided
117160
):
118-
model_builder = ModelBuilder(model=model_id, model_metadata={"HF_TASK": task_provided})
119-
161+
model_builder = ModelBuilder(
162+
model=model_id,
163+
model_metadata={"HF_TASK": task_provided},
164+
instance_type=instance_type_provided,
165+
)
120166
model = model_builder.build(sagemaker_session=sagemaker_session)
121167

122168
assert model is not None
123169
assert model_builder.schema_builder is not None
124170

125-
inputs, outputs = task.retrieve_local_schemas(task_provided)
171+
remote_hf_schema_helper = remote_schema_retriever.RemoteSchemaRetriever()
172+
inputs, outputs = remote_hf_schema_helper.get_resolved_hf_schema_for_task(task_provided)
126173
assert model_builder.schema_builder.sample_input == inputs
127174
assert model_builder.schema_builder.sample_output == outputs
128175

@@ -134,7 +181,7 @@ def test_model_builder_happy_path_with_task_provided(
134181

135182
logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
136183
predictor = model.deploy(
137-
role=role_arn, instance_count=1, instance_type=gpu_instance_type
184+
role=role_arn, instance_count=1, instance_type=instance_type_provided
138185
)
139186

140187
predicted_outputs = predictor.predict(inputs)
@@ -162,6 +209,7 @@ def test_model_builder_negative_path_with_invalid_task(sagemaker_session):
162209

163210
with pytest.raises(
164211
TaskNotFoundException,
165-
match="Error Message: Schema builder for invalid-task could not be found.",
212+
match="Error Message: HuggingFace Schema builder samples for invalid-task could not be found locally or "
213+
"via remote.",
166214
):
167215
model_builder.build(sagemaker_session=sagemaker_session)

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,7 +1062,7 @@ def test_build_negative_path_when_schema_builder_not_present(
10621062

10631063
# HF Pipeline Tag
10641064
mock_model_uris_retrieve.side_effect = KeyError
1065-
mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-to-image"}
1065+
mock_llm_utils_json.load.return_value = {"pipeline_tag": "unsupported-task"}
10661066
mock_llm_utils_urllib.request.Request.side_effect = Mock()
10671067

10681068
# HF Model config
@@ -1075,7 +1075,8 @@ def test_build_negative_path_when_schema_builder_not_present(
10751075

10761076
self.assertRaisesRegex(
10771077
TaskNotFoundException,
1078-
"Error Message: Schema builder for text-to-image could not be found.",
1078+
"Error Message: HuggingFace Schema builder samples for unsupported-task could not be found locally or via "
1079+
"remote.",
10791080
lambda: model_builder.build(sagemaker_session=mock_session),
10801081
)
10811082

@@ -1627,7 +1628,8 @@ def test_build_task_override_with_invalid_task_provided(
16271628

16281629
self.assertRaisesRegex(
16291630
TaskNotFoundException,
1630-
f"Error Message: Schema builder for {provided_task} could not be found.",
1631+
f"Error Message: HuggingFace Schema builder samples for {provided_task} could not be found locally or "
1632+
f"via remote.",
16311633
lambda: model_builder.build(sagemaker_session=mock_session),
16321634
)
16331635

tests/unit/sagemaker/serve/utils/test_task.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,29 @@
1818

1919
from sagemaker.serve.utils import task
2020

21-
EXPECTED_INPUTS = {"inputs": "Paris is the [MASK] of France.", "parameters": {}}
22-
EXPECTED_OUTPUTS = [{"sequence": "Paris is the capital of France.", "score": 0.7}]
2321
HF_INVALID_TASK = "not-present-task"
2422

2523

26-
def test_retrieve_local_schemas_success():
27-
inputs, outputs = task.retrieve_local_schemas("fill-mask")
24+
def test_retrieve_local_schemas_text_generation_success():
25+
inputs, outputs = task.retrieve_local_schemas("text-generation")
2826

29-
assert inputs == EXPECTED_INPUTS
30-
assert outputs == EXPECTED_OUTPUTS
27+
assert inputs == {"inputs": "Hello, I'm a language model", "parameters": {}}
28+
assert outputs == [
29+
{
30+
"generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to "
31+
"meet my wife or come home she told me that my"
32+
}
33+
]
3134

3235

33-
def test_retrieve_local_schemas_text_generation_success():
34-
inputs, outputs = task.retrieve_local_schemas("text-generation")
36+
def test_retrieve_local_schemas_text_classification_success():
37+
inputs, outputs = task.retrieve_local_schemas("text-classification")
3538

36-
assert inputs is not None
37-
assert outputs is not None
39+
assert inputs == {
40+
"inputs": "Where is the capital of France?, Paris is the capital of France.",
41+
"parameters": {},
42+
}
43+
assert outputs == [{"label": "entailment", "score": 0.997}]
3844

3945

4046
def test_retrieve_local_schemas_throws():

0 commit comments

Comments
 (0)