@@ -1076,7 +1076,7 @@ def test_build_negative_path_when_schema_builder_not_present(
1076
1076
1077
1077
model_builder = ModelBuilder (model = "CompVis/stable-diffusion-v1-4" )
1078
1078
1079
- self .assertRaisesRegexp (
1079
+ self .assertRaisesRegex (
1080
1080
TaskNotFoundException ,
1081
1081
"Error Message: Schema builder for text-to-image could not be found." ,
1082
1082
lambda : model_builder .build (sagemaker_session = mock_session ),
@@ -1593,3 +1593,126 @@ def test_total_inference_model_size_mib_throws(
1593
1593
model_builder .build (sagemaker_session = mock_session )
1594
1594
1595
1595
self .assertEqual (model_builder ._can_fit_on_single_gpu (), False )
1596
+
1597
+ @patch ("sagemaker.serve.builder.tgi_builder.HuggingFaceModel" )
1598
+ @patch ("sagemaker.image_uris.retrieve" )
1599
+ @patch ("sagemaker.djl_inference.model.urllib" )
1600
+ @patch ("sagemaker.djl_inference.model.json" )
1601
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
1602
+ @patch ("sagemaker.huggingface.llm_utils.json" )
1603
+ @patch ("sagemaker.model_uris.retrieve" )
1604
+ @patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
1605
+ def test_build_happy_path_override_with_task_provided (
1606
+ self ,
1607
+ mock_serveSettings ,
1608
+ mock_model_uris_retrieve ,
1609
+ mock_llm_utils_json ,
1610
+ mock_llm_utils_urllib ,
1611
+ mock_model_json ,
1612
+ mock_model_urllib ,
1613
+ mock_image_uris_retrieve ,
1614
+ mock_hf_model ,
1615
+ ):
1616
+ # Setup mocks
1617
+
1618
+ mock_setting_object = mock_serveSettings .return_value
1619
+ mock_setting_object .role_arn = mock_role_arn
1620
+ mock_setting_object .s3_model_data_url = mock_s3_model_data_url
1621
+
1622
+ # HF Pipeline Tag
1623
+ mock_model_uris_retrieve .side_effect = KeyError
1624
+ mock_llm_utils_json .load .return_value = {"pipeline_tag" : "fill-mask" }
1625
+ mock_llm_utils_urllib .request .Request .side_effect = Mock ()
1626
+
1627
+ # HF Model config
1628
+ mock_model_json .load .return_value = {"some" : "config" }
1629
+ mock_model_urllib .request .Request .side_effect = Mock ()
1630
+
1631
+ mock_image_uris_retrieve .return_value = "https://some-image-uri"
1632
+
1633
+ model_builder = ModelBuilder (
1634
+ model = "bert-base-uncased" , model_metadata = {"HF_TASK" : "text-generation" }
1635
+ )
1636
+ model_builder .build (sagemaker_session = mock_session )
1637
+
1638
+ self .assertIsNotNone (model_builder .schema_builder )
1639
+ sample_inputs , sample_outputs = task .retrieve_local_schemas ("text-generation" )
1640
+ self .assertEqual (
1641
+ sample_inputs ["inputs" ], model_builder .schema_builder .sample_input ["inputs" ]
1642
+ )
1643
+ self .assertEqual (sample_outputs , model_builder .schema_builder .sample_output )
1644
+
1645
+ @patch ("sagemaker.image_uris.retrieve" )
1646
+ @patch ("sagemaker.djl_inference.model.urllib" )
1647
+ @patch ("sagemaker.djl_inference.model.json" )
1648
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
1649
+ @patch ("sagemaker.huggingface.llm_utils.json" )
1650
+ @patch ("sagemaker.model_uris.retrieve" )
1651
+ @patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
1652
+ def test_build_task_override_with_invalid_task_provided (
1653
+ self ,
1654
+ mock_serveSettings ,
1655
+ mock_model_uris_retrieve ,
1656
+ mock_llm_utils_json ,
1657
+ mock_llm_utils_urllib ,
1658
+ mock_model_json ,
1659
+ mock_model_urllib ,
1660
+ mock_image_uris_retrieve ,
1661
+ ):
1662
+ # Setup mocks
1663
+
1664
+ mock_setting_object = mock_serveSettings .return_value
1665
+ mock_setting_object .role_arn = mock_role_arn
1666
+ mock_setting_object .s3_model_data_url = mock_s3_model_data_url
1667
+
1668
+ # HF Pipeline Tag
1669
+ mock_model_uris_retrieve .side_effect = KeyError
1670
+ mock_llm_utils_json .load .return_value = {"pipeline_tag" : "fill-mask" }
1671
+ mock_llm_utils_urllib .request .Request .side_effect = Mock ()
1672
+
1673
+ # HF Model config
1674
+ mock_model_json .load .return_value = {"some" : "config" }
1675
+ mock_model_urllib .request .Request .side_effect = Mock ()
1676
+
1677
+ mock_image_uris_retrieve .return_value = "https://some-image-uri"
1678
+ model_ids_with_invalid_task = {
1679
+ "bert-base-uncased" : "invalid-task" ,
1680
+ "bert-large-uncased-whole-word-masking-finetuned-squad" : "" ,
1681
+ }
1682
+ for model_id in model_ids_with_invalid_task :
1683
+ provided_task = model_ids_with_invalid_task [model_id ]
1684
+ model_builder = ModelBuilder (model = model_id , model_metadata = {"HF_TASK" : provided_task })
1685
+
1686
+ self .assertRaisesRegex (
1687
+ TaskNotFoundException ,
1688
+ f"Error Message: Schema builder for { provided_task } could not be found." ,
1689
+ lambda : model_builder .build (sagemaker_session = mock_session ),
1690
+ )
1691
+
1692
+ @patch ("sagemaker.image_uris.retrieve" )
1693
+ @patch ("sagemaker.model_uris.retrieve" )
1694
+ @patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
1695
+ def test_build_task_override_with_invalid_model_provided (
1696
+ self ,
1697
+ mock_serveSettings ,
1698
+ mock_model_uris_retrieve ,
1699
+ mock_image_uris_retrieve ,
1700
+ ):
1701
+ # Setup mocks
1702
+
1703
+ mock_setting_object = mock_serveSettings .return_value
1704
+ mock_setting_object .role_arn = mock_role_arn
1705
+ mock_setting_object .s3_model_data_url = mock_s3_model_data_url
1706
+
1707
+ # HF Pipeline Tag
1708
+ mock_model_uris_retrieve .side_effect = KeyError
1709
+
1710
+ mock_image_uris_retrieve .return_value = "https://some-image-uri"
1711
+ invalid_model_id = ""
1712
+ provided_task = "fill-mask"
1713
+
1714
+ model_builder = ModelBuilder (
1715
+ model = invalid_model_id , model_metadata = {"HF_TASK" : provided_task }
1716
+ )
1717
+ with self .assertRaises (Exception ):
1718
+ model_builder .build (sagemaker_session = mock_session )
0 commit comments