Skip to content

Commit 15a501d

Browse files
author
Xiong Zeng
committed
Add extra unit tests for invalid inputs
1 parent 7680368 commit 15a501d

File tree

3 files changed

+39
-11
lines changed

3 files changed

+39
-11
lines changed

src/sagemaker/huggingface/llm_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ def get_huggingface_model_metadata(model_id: str, hf_hub_token: Optional[str] =
8181
Returns:
8282
dict: The model metadata retrieved with the HuggingFace API
8383
"""
84-
84+
if len(model_id) == 0:
85+
raise ValueError("Model ID is empty. Please provide a valid Model ID.")
8586
hf_model_metadata_url = f"https://huggingface.co/api/models/{model_id}"
8687
hf_model_metadata_json = None
8788
try:

src/sagemaker/serve/builder/model_builder.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -627,9 +627,8 @@ def build(
627627

628628
if model_task is None:
629629
model_task = hf_model_md.get("pipeline_tag")
630-
if self.schema_builder is None and model_task:
630+
if self.schema_builder is None and model_task is not None:
631631
self._schema_builder_init(model_task)
632-
633632
if model_task == "text-generation": # pylint: disable=R1705
634633
return self._build_for_tgi()
635634
else:

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,7 +1072,7 @@ def test_build_negative_path_when_schema_builder_not_present(
10721072

10731073
model_builder = ModelBuilder(model="CompVis/stable-diffusion-v1-4")
10741074

1075-
self.assertRaisesRegexp(
1075+
self.assertRaisesRegex(
10761076
TaskNotFoundException,
10771077
"Error Message: Schema builder for text-to-image could not be found.",
10781078
lambda: model_builder.build(sagemaker_session=mock_session),
@@ -1131,7 +1131,7 @@ def test_build_happy_path_override_with_task_provided(
11311131
@patch("sagemaker.huggingface.llm_utils.json")
11321132
@patch("sagemaker.model_uris.retrieve")
11331133
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
1134-
def test_build_negative_path_override_with_task_provided(
1134+
def test_build_task_override_with_invalid_task_provided(
11351135
self,
11361136
mock_serveSettings,
11371137
mock_model_uris_retrieve,
@@ -1157,11 +1157,39 @@ def test_build_negative_path_override_with_task_provided(
11571157
mock_model_urllib.request.Request.side_effect = Mock()
11581158

11591159
mock_image_uris_retrieve.return_value = "https://some-image-uri"
1160+
model_ids_with_invalid_task = ["bert-base-uncased:invalid-task", "bert-base-uncased:"]
1161+
for model_id in model_ids_with_invalid_task:
1162+
model_builder = ModelBuilder(model=model_id)
1163+
1164+
provided_task = model_id.split(":")[1]
1165+
self.assertRaisesRegex(
1166+
TaskNotFoundException,
1167+
f"Error Message: Schema builder for {provided_task} could not be found.",
1168+
lambda: model_builder.build(sagemaker_session=mock_session),
1169+
)
11601170

1161-
model_builder = ModelBuilder(model="bert-base-uncased:invalid-task")
1171+
@patch("sagemaker.image_uris.retrieve")
1172+
@patch("sagemaker.model_uris.retrieve")
1173+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
1174+
def test_build_task_override_with_invalid_model_provided(
1175+
self,
1176+
mock_serveSettings,
1177+
mock_model_uris_retrieve,
1178+
mock_image_uris_retrieve,
1179+
):
1180+
# Setup mocks
11621181

1163-
self.assertRaisesRegexp(
1164-
TaskNotFoundException,
1165-
"Error Message: Schema builder for invalid-task could not be found.",
1166-
lambda: model_builder.build(sagemaker_session=mock_session),
1167-
)
1182+
mock_setting_object = mock_serveSettings.return_value
1183+
mock_setting_object.role_arn = mock_role_arn
1184+
mock_setting_object.s3_model_data_url = mock_s3_model_data_url
1185+
1186+
# HF Pipeline Tag
1187+
mock_model_uris_retrieve.side_effect = KeyError
1188+
1189+
mock_image_uris_retrieve.return_value = "https://some-image-uri"
1190+
invalid_model_ids_with_task = [":fill-mask", "bert-base-uncased;fill-mask"]
1191+
1192+
for model_id in invalid_model_ids_with_task:
1193+
model_builder = ModelBuilder(model=model_id)
1194+
with self.assertRaises(Exception):
1195+
model_builder.build(sagemaker_session=mock_session)

0 commit comments

Comments
 (0)