Skip to content

Commit bdeb84b

Browse files
author
Xiong Zeng
committed
Add override logic in ModelBuilder with task provided
1 parent 9e4e2ec commit bdeb84b

File tree

4 files changed

+172
-22
lines changed

4 files changed

+172
-22
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
118118
into a stream. All translations between the server and the client are handled
119119
automatically with the specified input and output.
120120
model (Optional[Union[object, str]]): Model object (with ``predict`` method to perform
121-
inference) or a HuggingFace/JumpStart Model ID. Either ``model`` or
121+
inference) or a HuggingFace/JumpStart Model ID (followed by ``:task`` if you need
122+
to override the task, e.g. bert-base-uncased:fill-mask). Either ``model`` or
122123
``inference_spec`` is required for the model builder to build the artifact.
123124
inference_spec (InferenceSpec): The inference spec file with your customized
124125
``invoke`` and ``load`` functions.
@@ -205,6 +206,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
205206
"help": (
206207
'Model object with "predict" method to perform inference '
207208
"or HuggingFace/JumpStart Model ID"
209+
"or if you need to override task, provide input as ModelID:Task"
208210
)
209211
},
210212
)
@@ -610,6 +612,10 @@ def build(
610612
self._is_custom_image_uri = self.image_uri is not None
611613

612614
if isinstance(self.model, str):
615+
model_task = None
616+
if ":" in self.model:
617+
model_task = self.model.split(":")[1]
618+
self.model = self.model.split(":")[0]
613619
if self._is_jumpstart_model_id():
614620
return self._build_for_jumpstart()
615621
if self._is_djl(): # pylint: disable=R1705
@@ -619,7 +625,8 @@ def build(
619625
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
620626
)
621627

622-
model_task = hf_model_md.get("pipeline_tag")
628+
if model_task is None:
629+
model_task = hf_model_md.get("pipeline_tag")
623630
if self.schema_builder is None and model_task:
624631
self._schema_builder_init(model_task)
625632

src/sagemaker/serve/schema/task.json

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,28 @@
11
{
22
"fill-mask": {
3-
"sample_inputs": {
3+
"sample_inputs": {
44
"properties": {
5-
"inputs": "Paris is the <mask> of France.",
5+
"inputs": "Paris is the [MASK] of France.",
66
"parameters": {}
77
}
88
},
9-
"sample_outputs": {
9+
"sample_outputs": {
1010
"properties": [
1111
{
1212
"sequence": "Paris is the capital of France.",
1313
"score": 0.7
1414
}
1515
]
1616
}
17-
},
17+
},
1818
"question-answering": {
19-
"sample_inputs": {
19+
"sample_inputs": {
2020
"properties": {
2121
"context": "I have a German Shepherd dog, named Coco.",
2222
"question": "What is my dog's breed?"
2323
}
24-
},
25-
"sample_outputs": {
24+
},
25+
"sample_outputs": {
2626
"properties": [
2727
{
2828
"answer": "German Shepherd",
@@ -32,36 +32,36 @@
3232
}
3333
]
3434
}
35-
},
35+
},
3636
"text-classification": {
37-
"sample_inputs": {
37+
"sample_inputs": {
3838
"properties": {
3939
"inputs": "Where is the capital of France?, Paris is the capital of France.",
4040
"parameters": {}
4141
}
42-
},
43-
"sample_outputs": {
42+
},
43+
"sample_outputs": {
4444
"properties": [
4545
{
4646
"label": "entailment",
4747
"score": 0.997
4848
}
4949
]
5050
}
51-
},
52-
"text-generation": {
53-
"sample_inputs": {
51+
},
52+
"text-generation": {
53+
"sample_inputs": {
5454
"properties": {
5555
"inputs": "Hello, I'm a language model",
5656
"parameters": {}
5757
}
58-
},
59-
"sample_outputs": {
58+
},
59+
"sample_outputs": {
6060
"properties": [
61-
{
62-
"generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to meet my wife or come home she told me that my"
63-
}
61+
{
62+
"generated_text": "Hello, I'm a language modeler. So while writing this, when I went out to meet my wife or come home she told me that my"
63+
}
6464
]
6565
}
66-
}
66+
}
6767
}

tests/integ/sagemaker/serve/test_schema_builder.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,58 @@ def test_model_builder_negative_path(sagemaker_session):
9999
match="Error Message: Schema builder for text-to-image could not be found.",
100100
):
101101
model_builder.build(sagemaker_session=sagemaker_session)
102+
103+
104+
@pytest.mark.skipif(
105+
PYTHON_VERSION_IS_NOT_310,
106+
reason="Testing Schema Builder Simplification feature",
107+
)
108+
def test_model_builder_happy_path_with_task_provided(sagemaker_session, gpu_instance_type):
109+
model_builder = ModelBuilder(model="bert-base-uncased:fill-mask")
110+
111+
model = model_builder.build(sagemaker_session=sagemaker_session)
112+
113+
assert model is not None
114+
assert model_builder.schema_builder is not None
115+
116+
inputs, outputs = task.retrieve_local_schemas("fill-mask")
117+
assert model_builder.schema_builder.sample_input == inputs
118+
assert model_builder.schema_builder.sample_output == outputs
119+
120+
with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
121+
caught_ex = None
122+
try:
123+
iam_client = sagemaker_session.boto_session.client("iam")
124+
role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]
125+
126+
logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
127+
predictor = model.deploy(
128+
role=role_arn, instance_count=1, instance_type=gpu_instance_type
129+
)
130+
131+
predicted_outputs = predictor.predict(inputs)
132+
assert predicted_outputs is not None
133+
134+
except Exception as e:
135+
caught_ex = e
136+
finally:
137+
cleanup_model_resources(
138+
sagemaker_session=model_builder.sagemaker_session,
139+
model_name=model.name,
140+
endpoint_name=model.endpoint_name,
141+
)
142+
if caught_ex:
143+
logger.exception(caught_ex)
144+
assert (
145+
False
146+
), f"{caught_ex} was thrown when running transformers sagemaker endpoint test"
147+
148+
149+
def test_model_builder_negative_path_with_invalid_task(sagemaker_session):
150+
model_builder = ModelBuilder(model="bert-base-uncased:invalid-task")
151+
152+
with pytest.raises(
153+
TaskNotFoundException,
154+
match="Error Message: Schema builder for invalid-task could not be found.",
155+
):
156+
model_builder.build(sagemaker_session=sagemaker_session)

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,3 +1077,91 @@ def test_build_negative_path_when_schema_builder_not_present(
10771077
"Error Message: Schema builder for text-to-image could not be found.",
10781078
lambda: model_builder.build(sagemaker_session=mock_session),
10791079
)
1080+
1081+
@patch("sagemaker.serve.builder.tgi_builder.HuggingFaceModel")
1082+
@patch("sagemaker.image_uris.retrieve")
1083+
@patch("sagemaker.djl_inference.model.urllib")
1084+
@patch("sagemaker.djl_inference.model.json")
1085+
@patch("sagemaker.huggingface.llm_utils.urllib")
1086+
@patch("sagemaker.huggingface.llm_utils.json")
1087+
@patch("sagemaker.model_uris.retrieve")
1088+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
1089+
def test_build_happy_path_override_with_task_provided(
1090+
self,
1091+
mock_serveSettings,
1092+
mock_model_uris_retrieve,
1093+
mock_llm_utils_json,
1094+
mock_llm_utils_urllib,
1095+
mock_model_json,
1096+
mock_model_urllib,
1097+
mock_image_uris_retrieve,
1098+
mock_hf_model,
1099+
):
1100+
# Setup mocks
1101+
1102+
mock_setting_object = mock_serveSettings.return_value
1103+
mock_setting_object.role_arn = mock_role_arn
1104+
mock_setting_object.s3_model_data_url = mock_s3_model_data_url
1105+
1106+
# HF Pipeline Tag
1107+
mock_model_uris_retrieve.side_effect = KeyError
1108+
mock_llm_utils_json.load.return_value = {"pipeline_tag": "fill-mask"}
1109+
mock_llm_utils_urllib.request.Request.side_effect = Mock()
1110+
1111+
# HF Model config
1112+
mock_model_json.load.return_value = {"some": "config"}
1113+
mock_model_urllib.request.Request.side_effect = Mock()
1114+
1115+
mock_image_uris_retrieve.return_value = "https://some-image-uri"
1116+
1117+
model_builder = ModelBuilder(model="bert-base-uncased:text-generation")
1118+
model_builder.build(sagemaker_session=mock_session)
1119+
1120+
self.assertIsNotNone(model_builder.schema_builder)
1121+
sample_inputs, sample_outputs = task.retrieve_local_schemas("text-generation")
1122+
self.assertEqual(
1123+
sample_inputs["inputs"], model_builder.schema_builder.sample_input["inputs"]
1124+
)
1125+
self.assertEqual(sample_outputs, model_builder.schema_builder.sample_output)
1126+
1127+
@patch("sagemaker.image_uris.retrieve")
1128+
@patch("sagemaker.djl_inference.model.urllib")
1129+
@patch("sagemaker.djl_inference.model.json")
1130+
@patch("sagemaker.huggingface.llm_utils.urllib")
1131+
@patch("sagemaker.huggingface.llm_utils.json")
1132+
@patch("sagemaker.model_uris.retrieve")
1133+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
1134+
def test_build_negative_path_override_with_task_provided(
1135+
self,
1136+
mock_serveSettings,
1137+
mock_model_uris_retrieve,
1138+
mock_llm_utils_json,
1139+
mock_llm_utils_urllib,
1140+
mock_model_json,
1141+
mock_model_urllib,
1142+
mock_image_uris_retrieve,
1143+
):
1144+
# Setup mocks
1145+
1146+
mock_setting_object = mock_serveSettings.return_value
1147+
mock_setting_object.role_arn = mock_role_arn
1148+
mock_setting_object.s3_model_data_url = mock_s3_model_data_url
1149+
1150+
# HF Pipeline Tag
1151+
mock_model_uris_retrieve.side_effect = KeyError
1152+
mock_llm_utils_json.load.return_value = {"pipeline_tag": "fill-mask"}
1153+
mock_llm_utils_urllib.request.Request.side_effect = Mock()
1154+
1155+
# HF Model config
1156+
mock_model_json.load.return_value = {"some": "config"}
1157+
mock_model_urllib.request.Request.side_effect = Mock()
1158+
1159+
mock_image_uris_retrieve.return_value = "https://some-image-uri"
1160+
1161+
model_builder = ModelBuilder(model="bert-base-uncased:invalid-task")
1162+
1163+
self.assertRaisesRegexp(
1164+
TaskNotFoundException,
1165+
"Error Message: Schema builder for invalid-task could not be found.",
1166+
lambda: model_builder.build(sagemaker_session=mock_session),
1167+
)

0 commit comments

Comments
 (0)