42
42
mock_s3_model_data_url = "sample s3 data url"
43
43
mock_secret_key = "mock_secret_key"
44
44
mock_instance_type = "mock instance type"
45
+ MOCK_HF_MODEL_METADATA_JSON = {"mock_key" : "mock_value" }
45
46
46
47
supported_model_server = {
47
48
ModelServer .TORCHSERVE ,
54
55
55
56
class TestModelBuilder (unittest .TestCase ):
56
57
@patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
57
- def test_validation_in_progress_mode_not_supported (self , mock_serveSettings ):
58
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
59
+ @patch ("sagemaker.huggingface.llm_utils.json" )
60
+ def test_validation_in_progress_mode_not_supported (
61
+ self , mock_serveSettings , mock_urllib , mock_json
62
+ ):
63
+ mock_json .load .return_value = MOCK_HF_MODEL_METADATA_JSON
64
+ mock_hf_model_metadata_url = Mock ()
65
+ mock_urllib .request .Request .side_effect = mock_hf_model_metadata_url
66
+
58
67
builder = ModelBuilder ()
59
68
self .assertRaisesRegex (
60
69
Exception ,
@@ -66,7 +75,15 @@ def test_validation_in_progress_mode_not_supported(self, mock_serveSettings):
66
75
)
67
76
68
77
@patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
69
- def test_validation_cannot_set_both_model_and_inference_spec (self , mock_serveSettings ):
78
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
79
+ @patch ("sagemaker.huggingface.llm_utils.json" )
80
+ def test_validation_cannot_set_both_model_and_inference_spec (
81
+ self , mock_serveSettings , mock_urllib , mock_json
82
+ ):
83
+ mock_json .load .return_value = MOCK_HF_MODEL_METADATA_JSON
84
+ mock_hf_model_metadata_url = Mock ()
85
+ mock_urllib .request .Request .side_effect = mock_hf_model_metadata_url
86
+
70
87
builder = ModelBuilder (inference_spec = "some value" , model = Mock (spec = object ))
71
88
self .assertRaisesRegex (
72
89
Exception ,
@@ -78,7 +95,15 @@ def test_validation_cannot_set_both_model_and_inference_spec(self, mock_serveSet
78
95
)
79
96
80
97
@patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
81
- def test_validation_unsupported_model_server_type (self , mock_serveSettings ):
98
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
99
+ @patch ("sagemaker.huggingface.llm_utils.json" )
100
+ def test_validation_unsupported_model_server_type (
101
+ self , mock_serveSettings , mock_urllib , mock_json
102
+ ):
103
+ mock_json .load .return_value = MOCK_HF_MODEL_METADATA_JSON
104
+ mock_hf_model_metadata_url = Mock ()
105
+ mock_urllib .request .Request .side_effect = mock_hf_model_metadata_url
106
+
82
107
builder = ModelBuilder (model_server = "invalid_model_server" )
83
108
self .assertRaisesRegex (
84
109
Exception ,
@@ -91,7 +116,15 @@ def test_validation_unsupported_model_server_type(self, mock_serveSettings):
91
116
)
92
117
93
118
@patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
94
- def test_validation_model_server_not_set_with_image_uri (self , mock_serveSettings ):
119
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
120
+ @patch ("sagemaker.huggingface.llm_utils.json" )
121
+ def test_validation_model_server_not_set_with_image_uri (
122
+ self , mock_serveSettings , mock_urllib , mock_json
123
+ ):
124
+ mock_json .load .return_value = MOCK_HF_MODEL_METADATA_JSON
125
+ mock_hf_model_metadata_url = Mock ()
126
+ mock_urllib .request .Request .side_effect = mock_hf_model_metadata_url
127
+
95
128
builder = ModelBuilder (image_uri = "image_uri" )
96
129
self .assertRaisesRegex (
97
130
Exception ,
@@ -104,9 +137,15 @@ def test_validation_model_server_not_set_with_image_uri(self, mock_serveSettings
104
137
)
105
138
106
139
@patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
140
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
141
+ @patch ("sagemaker.huggingface.llm_utils.json" )
107
142
def test_save_model_throw_exception_when_none_of_model_and_inference_spec_is_set (
108
- self , mock_serveSettings
143
+ self , mock_serveSettings , mock_urllib , mock_json
109
144
):
145
+ mock_json .load .return_value = MOCK_HF_MODEL_METADATA_JSON
146
+ mock_hf_model_metadata_url = Mock ()
147
+ mock_urllib .request .Request .side_effect = mock_hf_model_metadata_url
148
+
110
149
builder = ModelBuilder (inference_spec = None , model = None )
111
150
self .assertRaisesRegex (
112
151
Exception ,
@@ -126,8 +165,12 @@ def test_save_model_throw_exception_when_none_of_model_and_inference_spec_is_set
126
165
@patch ("sagemaker.serve.builder.model_builder.SageMakerEndpointMode" )
127
166
@patch ("sagemaker.serve.builder.model_builder.Model" )
128
167
@patch ("os.path.exists" )
168
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
169
+ @patch ("sagemaker.huggingface.llm_utils.json" )
129
170
def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc (
130
171
self ,
172
+ mock_urllib ,
173
+ mock_json ,
131
174
mock_path_exists ,
132
175
mock_sdk_model ,
133
176
mock_sageMakerEndpointMode ,
@@ -146,6 +189,10 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc(
146
189
else None
147
190
)
148
191
192
+ mock_json .load .return_value = MOCK_HF_MODEL_METADATA_JSON
193
+ mock_hf_model_metadata_url = Mock ()
194
+ mock_urllib .request .Request .side_effect = mock_hf_model_metadata_url
195
+
149
196
mock_detect_fw_version .return_value = framework , version
150
197
151
198
mock_prepare_for_torchserve .side_effect = (
@@ -226,8 +273,12 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc(
226
273
@patch ("sagemaker.serve.builder.model_builder.SageMakerEndpointMode" )
227
274
@patch ("sagemaker.serve.builder.model_builder.Model" )
228
275
@patch ("os.path.exists" )
276
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
277
+ @patch ("sagemaker.huggingface.llm_utils.json" )
229
278
def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc (
230
279
self ,
280
+ mock_urllib ,
281
+ mock_json ,
231
282
mock_path_exists ,
232
283
mock_sdk_model ,
233
284
mock_sageMakerEndpointMode ,
@@ -246,6 +297,10 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc(
246
297
else None
247
298
)
248
299
300
+ mock_json .load .return_value = MOCK_HF_MODEL_METADATA_JSON
301
+ mock_hf_model_metadata_url = Mock ()
302
+ mock_urllib .request .Request .side_effect = mock_hf_model_metadata_url
303
+
249
304
mock_detect_fw_version .return_value = framework , version
250
305
251
306
mock_prepare_for_torchserve .side_effect = (
@@ -326,8 +381,12 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc(
326
381
@patch ("sagemaker.serve.builder.model_builder.SageMakerEndpointMode" )
327
382
@patch ("sagemaker.serve.builder.model_builder.Model" )
328
383
@patch ("os.path.exists" )
384
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
385
+ @patch ("sagemaker.huggingface.llm_utils.json" )
329
386
def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec (
330
387
self ,
388
+ mock_urllib ,
389
+ mock_json ,
331
390
mock_path_exists ,
332
391
mock_sdk_model ,
333
392
mock_sageMakerEndpointMode ,
@@ -343,6 +402,10 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec(
343
402
lambda model_path : mock_native_model if model_path == MODEL_PATH else None
344
403
)
345
404
405
+ mock_json .load .return_value = MOCK_HF_MODEL_METADATA_JSON
406
+ mock_hf_model_metadata_url = Mock ()
407
+ mock_urllib .request .Request .side_effect = mock_hf_model_metadata_url
408
+
346
409
mock_detect_fw_version .return_value = framework , version
347
410
348
411
mock_detect_container .side_effect = (
@@ -427,8 +490,12 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec(
427
490
@patch ("sagemaker.serve.builder.model_builder.SageMakerEndpointMode" )
428
491
@patch ("sagemaker.serve.builder.model_builder.Model" )
429
492
@patch ("os.path.exists" )
493
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
494
+ @patch ("sagemaker.huggingface.llm_utils.json" )
430
495
def test_build_happy_path_with_sagemakerEndpoint_mode_and_model (
431
496
self ,
497
+ mock_urllib ,
498
+ mock_json ,
432
499
mock_path_exists ,
433
500
mock_sdk_model ,
434
501
mock_sageMakerEndpointMode ,
@@ -447,6 +514,10 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model(
447
514
else None
448
515
)
449
516
517
+ mock_json .load .return_value = MOCK_HF_MODEL_METADATA_JSON
518
+ mock_hf_model_metadata_url = Mock ()
519
+ mock_urllib .request .Request .side_effect = mock_hf_model_metadata_url
520
+
450
521
mock_detect_fw_version .return_value = framework , version
451
522
452
523
mock_prepare_for_torchserve .side_effect = (
@@ -530,8 +601,12 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model(
530
601
@patch ("sagemaker.serve.builder.model_builder.SageMakerEndpointMode" )
531
602
@patch ("sagemaker.serve.builder.model_builder.Model" )
532
603
@patch ("os.path.exists" )
604
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
605
+ @patch ("sagemaker.huggingface.llm_utils.json" )
533
606
def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model (
534
607
self ,
608
+ mock_urllib ,
609
+ mock_json ,
535
610
mock_path_exists ,
536
611
mock_sdk_model ,
537
612
mock_sageMakerEndpointMode ,
@@ -551,6 +626,10 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model(
551
626
else None
552
627
)
553
628
629
+ mock_json .load .return_value = MOCK_HF_MODEL_METADATA_JSON
630
+ mock_hf_model_metadata_url = Mock ()
631
+ mock_urllib .request .Request .side_effect = mock_hf_model_metadata_url
632
+
554
633
mock_detect_fw_version .return_value = "xgboost" , version
555
634
556
635
mock_prepare_for_torchserve .side_effect = (
@@ -635,8 +714,12 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model(
635
714
@patch ("sagemaker.serve.builder.model_builder.LocalContainerMode" )
636
715
@patch ("sagemaker.serve.builder.model_builder.Model" )
637
716
@patch ("os.path.exists" )
717
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
718
+ @patch ("sagemaker.huggingface.llm_utils.json" )
638
719
def test_build_happy_path_with_local_container_mode (
639
720
self ,
721
+ mock_urllib ,
722
+ mock_json ,
640
723
mock_path_exists ,
641
724
mock_sdk_model ,
642
725
mock_localContainerMode ,
@@ -651,6 +734,10 @@ def test_build_happy_path_with_local_container_mode(
651
734
lambda model_path : mock_native_model if model_path == MODEL_PATH else None
652
735
)
653
736
737
+ mock_json .load .return_value = MOCK_HF_MODEL_METADATA_JSON
738
+ mock_hf_model_metadata_url = Mock ()
739
+ mock_urllib .request .Request .side_effect = mock_hf_model_metadata_url
740
+
654
741
mock_detect_container .side_effect = (
655
742
lambda model , region , instance_type : mock_image_uri
656
743
if model == mock_native_model
@@ -729,8 +816,12 @@ def test_build_happy_path_with_local_container_mode(
729
816
@patch ("sagemaker.serve.builder.model_builder.LocalContainerMode" )
730
817
@patch ("sagemaker.serve.builder.model_builder.Model" )
731
818
@patch ("os.path.exists" )
819
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
820
+ @patch ("sagemaker.huggingface.llm_utils.json" )
732
821
def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mode (
733
822
self ,
823
+ mock_urllib ,
824
+ mock_json ,
734
825
mock_path_exists ,
735
826
mock_sdk_model ,
736
827
mock_localContainerMode ,
@@ -747,6 +838,10 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo
747
838
lambda model_path : mock_native_model if model_path == MODEL_PATH else None
748
839
)
749
840
841
+ mock_json .load .return_value = MOCK_HF_MODEL_METADATA_JSON
842
+ mock_hf_model_metadata_url = Mock ()
843
+ mock_urllib .request .Request .side_effect = mock_hf_model_metadata_url
844
+
750
845
mock_detect_fw_version .return_value = framework , version
751
846
752
847
mock_detect_container .side_effect = (
@@ -870,8 +965,12 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo
870
965
@patch ("sagemaker.serve.builder.model_builder.LocalContainerMode" )
871
966
@patch ("sagemaker.serve.builder.model_builder.Model" )
872
967
@patch ("os.path.exists" )
968
+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
969
+ @patch ("sagemaker.huggingface.llm_utils.json" )
873
970
def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_container (
874
971
self ,
972
+ mock_urllib ,
973
+ mock_json ,
875
974
mock_path_exists ,
876
975
mock_sdk_model ,
877
976
mock_localContainerMode ,
@@ -885,6 +984,10 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co
885
984
# setup mocks
886
985
mock_detect_fw_version .return_value = framework , version
887
986
987
+ mock_json .load .return_value = MOCK_HF_MODEL_METADATA_JSON
988
+ mock_hf_model_metadata_url = Mock ()
989
+ mock_urllib .request .Request .side_effect = mock_hf_model_metadata_url
990
+
888
991
mock_detect_container .side_effect = (
889
992
lambda model , region , instance_type : mock_image_uri
890
993
if model == mock_fw_model
0 commit comments