12
12
# language governing permissions and limitations under the License.
13
13
from __future__ import absolute_import
14
14
import time
15
- from typing import Optional
15
+ from typing import Optional , Set
16
16
from unittest import mock
17
17
import unittest
18
18
from inspect import signature
@@ -142,8 +142,9 @@ def test_non_prepacked(
142
142
predictor_cls = Predictor ,
143
143
role = execution_role ,
144
144
wait = True ,
145
+ use_compiled_model = False ,
145
146
enable_network_isolation = False ,
146
- name = "blahblahblah-9876" ,
147
+ model_name = "blahblahblah-9876" ,
147
148
endpoint_name = "blahblahblah-9876" ,
148
149
)
149
150
@@ -234,6 +235,7 @@ def test_prepacked(
234
235
predictor_cls = Predictor ,
235
236
role = execution_role ,
236
237
wait = True ,
238
+ use_compiled_model = False ,
237
239
enable_network_isolation = False ,
238
240
)
239
241
@@ -503,7 +505,8 @@ def evaluate_estimator_workflow_with_kwargs(
503
505
"predictor_cls" : Predictor ,
504
506
"role" : init_kwargs ["role" ],
505
507
"enable_network_isolation" : False ,
506
- "name" : "blahblahblah-1234" ,
508
+ "use_compiled_model" : False ,
509
+ "model_name" : "blahblahblah-1234" ,
507
510
"endpoint_name" : "blahblahblah-1234" ,
508
511
},
509
512
deploy_kwargs ,
@@ -512,6 +515,15 @@ def evaluate_estimator_workflow_with_kwargs(
512
515
mock_estimator_deploy .assert_called_once_with (** expected_deploy_kwargs )
513
516
514
517
def test_jumpstart_estimator_kwargs_match_parent_class (self ):
518
+
519
+ """If you add arguments to <Estimator constructor>, this test will fail.
520
+ Please add the new argument to the skip set below,
521
+ and cut a ticket sev-3 to JumpStart team: AWS > SageMaker > JumpStart"""
522
+
523
+ init_args_to_skip : Set [str ] = set (["kwargs" ])
524
+ fit_args_to_skip : Set [str ] = set ()
525
+ deploy_args_to_skip : Set [str ] = set (["kwargs" ])
526
+
515
527
parent_class_init = Estimator .__init__
516
528
parent_class_init_args = set (signature (parent_class_init ).parameters .keys ())
517
529
@@ -525,6 +537,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self):
525
537
"tolerate_vulnerable_model" ,
526
538
"tolerate_deprecated_model" ,
527
539
}
540
+ assert parent_class_init_args - js_class_init_args == init_args_to_skip
528
541
529
542
parent_class_fit = Estimator .fit
530
543
parent_class_fit_args = set (signature (parent_class_fit ).parameters .keys ())
@@ -533,6 +546,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self):
533
546
js_class_fit_args = set (signature (js_class_fit ).parameters .keys ())
534
547
535
548
assert js_class_fit_args - parent_class_fit_args == set ()
549
+ assert parent_class_fit_args - js_class_fit_args == fit_args_to_skip
536
550
537
551
model_class_init = Model .__init__
538
552
model_class_init_args = set (signature (model_class_init ).parameters .keys ())
@@ -546,7 +560,9 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self):
546
560
assert js_class_deploy_args - parent_class_deploy_args == model_class_init_args - {
547
561
"model_data" ,
548
562
"self" ,
563
+ "name" ,
549
564
}
565
+ assert parent_class_deploy_args - js_class_deploy_args == deploy_args_to_skip
550
566
551
567
@mock .patch ("sagemaker.jumpstart.estimator.get_init_kwargs" )
552
568
@mock .patch ("sagemaker.jumpstart.estimator.Estimator.__init__" )
@@ -842,9 +858,10 @@ def test_training_passes_role_to_deploy(
842
858
},
843
859
predictor_cls = Predictor ,
844
860
wait = True ,
861
+ use_compiled_model = False ,
845
862
role = mock_role ,
846
863
enable_network_isolation = False ,
847
- name = "blahblahblah-3456" ,
864
+ model_name = "blahblahblah-3456" ,
848
865
endpoint_name = "blahblahblah-3456" ,
849
866
)
850
867
@@ -916,9 +933,10 @@ def test_training_passes_session_to_deploy(
916
933
},
917
934
predictor_cls = Predictor ,
918
935
wait = True ,
936
+ use_compiled_model = False ,
919
937
role = mock_role ,
920
938
enable_network_isolation = False ,
921
- name = "blahblahblah-3456" ,
939
+ model_name = "blahblahblah-3456" ,
922
940
endpoint_name = "blahblahblah-3456" ,
923
941
)
924
942
0 commit comments