13
13
14
14
from __future__ import print_function , absolute_import
15
15
16
+ import copy
17
+
16
18
from mock import patch , Mock , MagicMock
17
19
import pytest
18
20
23
25
ModelConfig ,
24
26
ModelPredictedLabelConfig ,
25
27
SHAPConfig ,
28
+ PDPConfig ,
26
29
)
27
30
from sagemaker import image_uris , Processor
28
31
@@ -304,6 +307,14 @@ def test_shap_config_no_parameters():
304
307
assert expected_config == shap_config .get_explainability_config ()
305
308
306
309
310
+ def test_pdp_config ():
311
+ pdp_config = PDPConfig (features = ["f1" , "f2" ], grid_resolution = 20 )
312
+ expected_config = {
313
+ "pdp" : {"features" : ["f1" , "f2" ], "grid_resolution" : 20 , "top_k_features" : 10 }
314
+ }
315
+ assert expected_config == pdp_config .get_explainability_config ()
316
+
317
+
307
318
def test_invalid_shap_config ():
308
319
with pytest .raises (ValueError ) as error :
309
320
SHAPConfig (
@@ -409,13 +420,18 @@ def shap_config():
409
420
0.26124998927116394 ,
410
421
0.2824999988079071 ,
411
422
0.06875000149011612 ,
412
- ]
423
+ ],
413
424
],
414
425
num_samples = 100 ,
415
426
agg_method = "mean_sq" ,
416
427
)
417
428
418
429
430
+ @pytest .fixture (scope = "module" )
431
+ def pdp_config ():
432
+ return PDPConfig (features = ["F1" , "F2" ], grid_resolution = 20 )
433
+
434
+
419
435
@patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
420
436
def test_pre_training_bias (
421
437
name_from_base ,
@@ -594,21 +610,30 @@ def test_run_on_s3_analysis_config_file(
594
610
)
595
611
596
612
597
- def _run_test_shap (
613
+ def _run_test_explain (
598
614
name_from_base ,
599
615
clarify_processor ,
600
616
clarify_processor_with_job_name_prefix ,
601
617
data_config ,
602
618
model_config ,
603
619
shap_config ,
620
+ pdp_config ,
604
621
model_scores ,
605
622
expected_predictor_config ,
606
623
):
607
624
with patch .object (SageMakerClarifyProcessor , "_run" , return_value = None ) as mock_method :
625
+ explanation_configs = None
626
+ if shap_config and pdp_config :
627
+ explanation_configs = [shap_config , pdp_config ]
628
+ elif shap_config :
629
+ explanation_configs = shap_config
630
+ elif pdp_config :
631
+ explanation_configs = pdp_config
632
+
608
633
clarify_processor .run_explainability (
609
634
data_config ,
610
635
model_config ,
611
- shap_config ,
636
+ explanation_configs ,
612
637
model_scores = model_scores ,
613
638
wait = True ,
614
639
job_name = "test" ,
@@ -623,23 +648,30 @@ def _run_test_shap(
623
648
"F3" ,
624
649
],
625
650
"label" : "Label" ,
626
- "methods" : {
627
- "shap" : {
628
- "baseline" : [
629
- [
630
- 0.26124998927116394 ,
631
- 0.2824999988079071 ,
632
- 0.06875000149011612 ,
633
- ]
634
- ],
635
- "num_samples" : 100 ,
636
- "agg_method" : "mean_sq" ,
637
- "use_logit" : False ,
638
- "save_local_shap_values" : True ,
639
- }
640
- },
641
651
"predictor" : expected_predictor_config ,
642
652
}
653
+ expected_explanation_configs = {}
654
+ if shap_config :
655
+ expected_explanation_configs ["shap" ] = {
656
+ "baseline" : [
657
+ [
658
+ 0.26124998927116394 ,
659
+ 0.2824999988079071 ,
660
+ 0.06875000149011612 ,
661
+ ]
662
+ ],
663
+ "num_samples" : 100 ,
664
+ "agg_method" : "mean_sq" ,
665
+ "use_logit" : False ,
666
+ "save_local_shap_values" : True ,
667
+ }
668
+ if pdp_config :
669
+ expected_explanation_configs ["pdp" ] = {
670
+ "features" : ["F1" , "F2" ],
671
+ "grid_resolution" : 20 ,
672
+ "top_k_features" : 10 ,
673
+ }
674
+ expected_analysis_config ["methods" ] = expected_explanation_configs
643
675
mock_method .assert_called_with (
644
676
data_config ,
645
677
expected_analysis_config ,
@@ -652,7 +684,7 @@ def _run_test_shap(
652
684
clarify_processor_with_job_name_prefix .run_explainability (
653
685
data_config ,
654
686
model_config ,
655
- shap_config ,
687
+ explanation_configs ,
656
688
model_scores = model_scores ,
657
689
wait = True ,
658
690
experiment_config = {"ExperimentName" : "AnExperiment" },
@@ -669,6 +701,34 @@ def _run_test_shap(
669
701
)
670
702
671
703
704
+ @patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
705
+ def test_pdp (
706
+ name_from_base ,
707
+ clarify_processor ,
708
+ clarify_processor_with_job_name_prefix ,
709
+ data_config ,
710
+ model_config ,
711
+ shap_config ,
712
+ pdp_config ,
713
+ ):
714
+ expected_predictor_config = {
715
+ "model_name" : "xgboost-model" ,
716
+ "instance_type" : "ml.c5.xlarge" ,
717
+ "initial_instance_count" : 1 ,
718
+ }
719
+ _run_test_explain (
720
+ name_from_base ,
721
+ clarify_processor ,
722
+ clarify_processor_with_job_name_prefix ,
723
+ data_config ,
724
+ model_config ,
725
+ None ,
726
+ pdp_config ,
727
+ None ,
728
+ expected_predictor_config ,
729
+ )
730
+
731
+
672
732
@patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
673
733
def test_shap (
674
734
name_from_base ,
@@ -683,18 +743,78 @@ def test_shap(
683
743
"instance_type" : "ml.c5.xlarge" ,
684
744
"initial_instance_count" : 1 ,
685
745
}
686
- _run_test_shap (
746
+ _run_test_explain (
687
747
name_from_base ,
688
748
clarify_processor ,
689
749
clarify_processor_with_job_name_prefix ,
690
750
data_config ,
691
751
model_config ,
692
752
shap_config ,
693
753
None ,
754
+ None ,
694
755
expected_predictor_config ,
695
756
)
696
757
697
758
759
+ @patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
760
+ def test_explainability_with_invalid_config (
761
+ name_from_base ,
762
+ clarify_processor ,
763
+ clarify_processor_with_job_name_prefix ,
764
+ data_config ,
765
+ model_config ,
766
+ ):
767
+ expected_predictor_config = {
768
+ "model_name" : "xgboost-model" ,
769
+ "instance_type" : "ml.c5.xlarge" ,
770
+ "initial_instance_count" : 1 ,
771
+ }
772
+ with pytest .raises (
773
+ AttributeError , match = "'NoneType' object has no attribute 'get_explainability_config'"
774
+ ):
775
+ _run_test_explain (
776
+ name_from_base ,
777
+ clarify_processor ,
778
+ clarify_processor_with_job_name_prefix ,
779
+ data_config ,
780
+ model_config ,
781
+ None ,
782
+ None ,
783
+ None ,
784
+ expected_predictor_config ,
785
+ )
786
+
787
+
788
+ @patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
789
+ def test_explainability_with_multiple_shap_config (
790
+ name_from_base ,
791
+ clarify_processor ,
792
+ clarify_processor_with_job_name_prefix ,
793
+ data_config ,
794
+ model_config ,
795
+ shap_config ,
796
+ ):
797
+ expected_predictor_config = {
798
+ "model_name" : "xgboost-model" ,
799
+ "instance_type" : "ml.c5.xlarge" ,
800
+ "initial_instance_count" : 1 ,
801
+ }
802
+ with pytest .raises (ValueError , match = "Duplicate explainability configs are provided" ):
803
+ second_shap_config = copy .deepcopy (shap_config )
804
+ second_shap_config .shap_config ["num_samples" ] = 200
805
+ _run_test_explain (
806
+ name_from_base ,
807
+ clarify_processor ,
808
+ clarify_processor_with_job_name_prefix ,
809
+ data_config ,
810
+ model_config ,
811
+ [shap_config , second_shap_config ],
812
+ None ,
813
+ None ,
814
+ expected_predictor_config ,
815
+ )
816
+
817
+
698
818
@patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
699
819
def test_shap_with_predicted_label (
700
820
name_from_base ,
@@ -703,6 +823,7 @@ def test_shap_with_predicted_label(
703
823
data_config ,
704
824
model_config ,
705
825
shap_config ,
826
+ pdp_config ,
706
827
):
707
828
probability = "pr"
708
829
label_headers = ["success" ]
@@ -717,13 +838,14 @@ def test_shap_with_predicted_label(
717
838
"probability" : probability ,
718
839
"label_headers" : label_headers ,
719
840
}
720
- _run_test_shap (
841
+ _run_test_explain (
721
842
name_from_base ,
722
843
clarify_processor ,
723
844
clarify_processor_with_job_name_prefix ,
724
845
data_config ,
725
846
model_config ,
726
847
shap_config ,
848
+ pdp_config ,
727
849
model_scores ,
728
850
expected_predictor_config ,
729
851
)
0 commit comments