@@ -718,6 +718,80 @@ def test_get_deployment_config(
718
718
719
719
self .assertEqual (builder .get_deployment_config (), expected )
720
720
721
+ @patch ("sagemaker.serve.builder.jumpstart_builder._capture_telemetry" , side_effect = None )
722
+ @patch (
723
+ "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id" ,
724
+ return_value = True ,
725
+ )
726
+ @patch (
727
+ "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model" ,
728
+ return_value = MagicMock (),
729
+ )
730
+ @patch (
731
+ "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources" ,
732
+ return_value = ({"model_type" : "t5" , "n_head" : 71 }, True ),
733
+ )
734
+ @patch ("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb" , return_value = 1024 )
735
+ @patch (
736
+ "sagemaker.serve.builder.jumpstart_builder._get_nb_instance" , return_value = "ml.g5.24xlarge"
737
+ )
738
+ def test_set_deployment_config (
739
+ self ,
740
+ mock_get_nb_instance ,
741
+ mock_get_ram_usage_mb ,
742
+ mock_prepare_for_tgi ,
743
+ mock_pre_trained_model ,
744
+ mock_is_jumpstart_model ,
745
+ mock_telemetry ,
746
+ ):
747
+ builder = ModelBuilder (
748
+ model = "facebook/galactica-mock-model-id" ,
749
+ schema_builder = mock_schema_builder ,
750
+ )
751
+
752
+ mock_pre_trained_model .return_value .image_uri = mock_tgi_image_uri
753
+
754
+ builder .build ()
755
+ builder .set_deployment_config ("config-1" )
756
+
757
+ mock_pre_trained_model .return_value .set_deployment_config .assert_called_with ("config-1" )
758
+
759
+ @patch ("sagemaker.serve.builder.jumpstart_builder._capture_telemetry" , side_effect = None )
760
+ @patch (
761
+ "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id" ,
762
+ return_value = True ,
763
+ )
764
+ @patch (
765
+ "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model" ,
766
+ return_value = MagicMock (),
767
+ )
768
+ @patch (
769
+ "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources" ,
770
+ return_value = ({"model_type" : "t5" , "n_head" : 71 }, True ),
771
+ )
772
+ @patch ("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb" , return_value = 1024 )
773
+ @patch (
774
+ "sagemaker.serve.builder.jumpstart_builder._get_nb_instance" , return_value = "ml.g5.24xlarge"
775
+ )
776
+ def test_set_deployment_config_ex (
777
+ self ,
778
+ mock_get_nb_instance ,
779
+ mock_get_ram_usage_mb ,
780
+ mock_prepare_for_tgi ,
781
+ mock_pre_trained_model ,
782
+ mock_is_jumpstart_model ,
783
+ mock_telemetry ,
784
+ ):
785
+ mock_pre_trained_model .return_value .image_uri = mock_tgi_image_uri
786
+
787
+ self .assertRaisesRegex (
788
+ Exception ,
789
+ "Cannot set deployment config to an uninitialized model." ,
790
+ lambda : ModelBuilder (
791
+ model = "facebook/galactica-mock-model-id" , schema_builder = mock_schema_builder
792
+ ).set_deployment_config ("config-2" ),
793
+ )
794
+
721
795
@patch ("sagemaker.serve.builder.jumpstart_builder._capture_telemetry" , side_effect = None )
722
796
@patch (
723
797
"sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id" ,
@@ -759,11 +833,3 @@ def test_display_benchmark_metrics(
759
833
builder .display_benchmark_metrics ()
760
834
761
835
mock_pre_trained_model .return_value .display_benchmark_metrics .assert_called_once ()
762
-
763
- def test_display_benchmark_metrics_ex (self ):
764
- self .assertRaises (
765
- Exception ,
766
- lambda : ModelBuilder (
767
- model = "facebook/galactica-mock-model-id" , schema_builder = mock_schema_builder
768
- ).display_benchmark_metrics (),
769
- )
0 commit comments