@@ -1072,7 +1072,7 @@ def test_build_negative_path_when_schema_builder_not_present(
1072
1072
1073
1073
model_builder = ModelBuilder (model = "CompVis/stable-diffusion-v1-4" )
1074
1074
1075
- self .assertRaisesRegexp (
1075
+ self .assertRaisesRegex (
1076
1076
TaskNotFoundException ,
1077
1077
"Error Message: Schema builder for text-to-image could not be found." ,
1078
1078
lambda : model_builder .build (sagemaker_session = mock_session ),
@@ -1131,7 +1131,7 @@ def test_build_happy_path_override_with_task_provided(
1131
1131
@patch ("sagemaker.huggingface.llm_utils.json" )
1132
1132
@patch ("sagemaker.model_uris.retrieve" )
1133
1133
@patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
1134
- def test_build_negative_path_override_with_task_provided (
1134
+ def test_build_task_override_with_invalid_task_provided (
1135
1135
self ,
1136
1136
mock_serveSettings ,
1137
1137
mock_model_uris_retrieve ,
@@ -1157,11 +1157,39 @@ def test_build_negative_path_override_with_task_provided(
1157
1157
mock_model_urllib .request .Request .side_effect = Mock ()
1158
1158
1159
1159
mock_image_uris_retrieve .return_value = "https://some-image-uri"
1160
+ model_ids_with_invalid_task = ["bert-base-uncased:invalid-task" , "bert-base-uncased:" ]
1161
+ for model_id in model_ids_with_invalid_task :
1162
+ model_builder = ModelBuilder (model = model_id )
1163
+
1164
+ provided_task = model_id .split (":" )[1 ]
1165
+ self .assertRaisesRegex (
1166
+ TaskNotFoundException ,
1167
+ f"Error Message: Schema builder for { provided_task } could not be found." ,
1168
+ lambda : model_builder .build (sagemaker_session = mock_session ),
1169
+ )
1160
1170
1161
- model_builder = ModelBuilder (model = "bert-base-uncased:invalid-task" )
1171
+ @patch ("sagemaker.image_uris.retrieve" )
1172
+ @patch ("sagemaker.model_uris.retrieve" )
1173
+ @patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
1174
+ def test_build_task_override_with_invalid_model_provided (
1175
+ self ,
1176
+ mock_serveSettings ,
1177
+ mock_model_uris_retrieve ,
1178
+ mock_image_uris_retrieve ,
1179
+ ):
1180
+ # Setup mocks
1162
1181
1163
- self .assertRaisesRegexp (
1164
- TaskNotFoundException ,
1165
- "Error Message: Schema builder for invalid-task could not be found." ,
1166
- lambda : model_builder .build (sagemaker_session = mock_session ),
1167
- )
1182
+ mock_setting_object = mock_serveSettings .return_value
1183
+ mock_setting_object .role_arn = mock_role_arn
1184
+ mock_setting_object .s3_model_data_url = mock_s3_model_data_url
1185
+
1186
+ # HF Pipeline Tag
1187
+ mock_model_uris_retrieve .side_effect = KeyError
1188
+
1189
+ mock_image_uris_retrieve .return_value = "https://some-image-uri"
1190
+ invalid_model_ids_with_task = [":fill-mask" , "bert-base-uncased;fill-mask" ]
1191
+
1192
+ for model_id in invalid_model_ids_with_task :
1193
+ model_builder = ModelBuilder (model = model_id )
1194
+ with self .assertRaises (Exception ):
1195
+ model_builder .build (sagemaker_session = mock_session )
0 commit comments