13
13
14
14
from __future__ import print_function , absolute_import
15
15
16
+ import copy
17
+
16
18
from mock import patch , Mock , MagicMock
17
19
import pytest
18
20
23
25
ModelConfig ,
24
26
ModelPredictedLabelConfig ,
25
27
SHAPConfig ,
28
+ PDPConfig ,
26
29
)
27
30
from sagemaker import image_uris , Processor
28
31
@@ -268,6 +271,14 @@ def test_shap_config():
268
271
assert expected_config == shap_config .get_explainability_config ()
269
272
270
273
274
+ def test_pdp_config ():
275
+ pdp_config = PDPConfig (features = ["f1" , "f2" ], grid_resolution = 20 )
276
+ expected_config = {
277
+ "pdp" : {"features" : ["f1" , "f2" ], "grid_resolution" : 20 , "top_k_features" : 10 }
278
+ }
279
+ assert expected_config == pdp_config .get_explainability_config ()
280
+
281
+
271
282
def test_invalid_shap_config ():
272
283
with pytest .raises (ValueError ) as error :
273
284
SHAPConfig (
@@ -367,13 +378,18 @@ def shap_config():
367
378
0.26124998927116394 ,
368
379
0.2824999988079071 ,
369
380
0.06875000149011612 ,
370
- ]
381
+ ],
371
382
],
372
383
num_samples = 100 ,
373
384
agg_method = "mean_sq" ,
374
385
)
375
386
376
387
388
+ @pytest .fixture (scope = "module" )
389
+ def pdp_config ():
390
+ return PDPConfig (features = ["F1" , "F2" ], grid_resolution = 20 )
391
+
392
+
377
393
@patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
378
394
def test_pre_training_bias (
379
395
name_from_base ,
@@ -552,21 +568,30 @@ def test_run_on_s3_analysis_config_file(
552
568
)
553
569
554
570
555
- def _run_test_shap (
571
+ def _run_test_explain (
556
572
name_from_base ,
557
573
clarify_processor ,
558
574
clarify_processor_with_job_name_prefix ,
559
575
data_config ,
560
576
model_config ,
561
577
shap_config ,
578
+ pdp_config ,
562
579
model_scores ,
563
580
expected_predictor_config ,
564
581
):
565
582
with patch .object (SageMakerClarifyProcessor , "_run" , return_value = None ) as mock_method :
583
+ explanation_configs = None
584
+ if shap_config and pdp_config :
585
+ explanation_configs = [shap_config , pdp_config ]
586
+ elif shap_config :
587
+ explanation_configs = shap_config
588
+ elif pdp_config :
589
+ explanation_configs = pdp_config
590
+
566
591
clarify_processor .run_explainability (
567
592
data_config ,
568
593
model_config ,
569
- shap_config ,
594
+ explanation_configs ,
570
595
model_scores = model_scores ,
571
596
wait = True ,
572
597
job_name = "test" ,
@@ -581,23 +606,30 @@ def _run_test_shap(
581
606
"F3" ,
582
607
],
583
608
"label" : "Label" ,
584
- "methods" : {
585
- "shap" : {
586
- "baseline" : [
587
- [
588
- 0.26124998927116394 ,
589
- 0.2824999988079071 ,
590
- 0.06875000149011612 ,
591
- ]
592
- ],
593
- "num_samples" : 100 ,
594
- "agg_method" : "mean_sq" ,
595
- "use_logit" : False ,
596
- "save_local_shap_values" : True ,
597
- }
598
- },
599
609
"predictor" : expected_predictor_config ,
600
610
}
611
+ expected_explanation_configs = {}
612
+ if shap_config :
613
+ expected_explanation_configs ["shap" ] = {
614
+ "baseline" : [
615
+ [
616
+ 0.26124998927116394 ,
617
+ 0.2824999988079071 ,
618
+ 0.06875000149011612 ,
619
+ ]
620
+ ],
621
+ "num_samples" : 100 ,
622
+ "agg_method" : "mean_sq" ,
623
+ "use_logit" : False ,
624
+ "save_local_shap_values" : True ,
625
+ }
626
+ if pdp_config :
627
+ expected_explanation_configs ["pdp" ] = {
628
+ "features" : ["F1" , "F2" ],
629
+ "grid_resolution" : 20 ,
630
+ "top_k_features" : 10 ,
631
+ }
632
+ expected_analysis_config ["methods" ] = expected_explanation_configs
601
633
mock_method .assert_called_with (
602
634
data_config ,
603
635
expected_analysis_config ,
@@ -610,7 +642,7 @@ def _run_test_shap(
610
642
clarify_processor_with_job_name_prefix .run_explainability (
611
643
data_config ,
612
644
model_config ,
613
- shap_config ,
645
+ explanation_configs ,
614
646
model_scores = model_scores ,
615
647
wait = True ,
616
648
experiment_config = {"ExperimentName" : "AnExperiment" },
@@ -627,6 +659,34 @@ def _run_test_shap(
627
659
)
628
660
629
661
662
+ @patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
663
+ def test_pdp (
664
+ name_from_base ,
665
+ clarify_processor ,
666
+ clarify_processor_with_job_name_prefix ,
667
+ data_config ,
668
+ model_config ,
669
+ shap_config ,
670
+ pdp_config ,
671
+ ):
672
+ expected_predictor_config = {
673
+ "model_name" : "xgboost-model" ,
674
+ "instance_type" : "ml.c5.xlarge" ,
675
+ "initial_instance_count" : 1 ,
676
+ }
677
+ _run_test_explain (
678
+ name_from_base ,
679
+ clarify_processor ,
680
+ clarify_processor_with_job_name_prefix ,
681
+ data_config ,
682
+ model_config ,
683
+ None ,
684
+ pdp_config ,
685
+ None ,
686
+ expected_predictor_config ,
687
+ )
688
+
689
+
630
690
@patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
631
691
def test_shap (
632
692
name_from_base ,
@@ -641,18 +701,78 @@ def test_shap(
641
701
"instance_type" : "ml.c5.xlarge" ,
642
702
"initial_instance_count" : 1 ,
643
703
}
644
- _run_test_shap (
704
+ _run_test_explain (
645
705
name_from_base ,
646
706
clarify_processor ,
647
707
clarify_processor_with_job_name_prefix ,
648
708
data_config ,
649
709
model_config ,
650
710
shap_config ,
651
711
None ,
712
+ None ,
652
713
expected_predictor_config ,
653
714
)
654
715
655
716
717
+ @patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
718
+ def test_explainability_with_invalid_config (
719
+ name_from_base ,
720
+ clarify_processor ,
721
+ clarify_processor_with_job_name_prefix ,
722
+ data_config ,
723
+ model_config ,
724
+ ):
725
+ expected_predictor_config = {
726
+ "model_name" : "xgboost-model" ,
727
+ "instance_type" : "ml.c5.xlarge" ,
728
+ "initial_instance_count" : 1 ,
729
+ }
730
+ with pytest .raises (
731
+ AttributeError , match = "'NoneType' object has no attribute 'get_explainability_config'"
732
+ ):
733
+ _run_test_explain (
734
+ name_from_base ,
735
+ clarify_processor ,
736
+ clarify_processor_with_job_name_prefix ,
737
+ data_config ,
738
+ model_config ,
739
+ None ,
740
+ None ,
741
+ None ,
742
+ expected_predictor_config ,
743
+ )
744
+
745
+
746
+ @patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
747
+ def test_explainability_with_multiple_shap_config (
748
+ name_from_base ,
749
+ clarify_processor ,
750
+ clarify_processor_with_job_name_prefix ,
751
+ data_config ,
752
+ model_config ,
753
+ shap_config ,
754
+ ):
755
+ expected_predictor_config = {
756
+ "model_name" : "xgboost-model" ,
757
+ "instance_type" : "ml.c5.xlarge" ,
758
+ "initial_instance_count" : 1 ,
759
+ }
760
+ with pytest .raises (ValueError , match = "Duplicate explainability configs are provided" ):
761
+ second_shap_config = copy .deepcopy (shap_config )
762
+ second_shap_config .shap_config ["num_samples" ] = 200
763
+ _run_test_explain (
764
+ name_from_base ,
765
+ clarify_processor ,
766
+ clarify_processor_with_job_name_prefix ,
767
+ data_config ,
768
+ model_config ,
769
+ [shap_config , second_shap_config ],
770
+ None ,
771
+ None ,
772
+ expected_predictor_config ,
773
+ )
774
+
775
+
656
776
@patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
657
777
def test_shap_with_predicted_label (
658
778
name_from_base ,
@@ -661,6 +781,7 @@ def test_shap_with_predicted_label(
661
781
data_config ,
662
782
model_config ,
663
783
shap_config ,
784
+ pdp_config ,
664
785
):
665
786
probability = "pr"
666
787
label_headers = ["success" ]
@@ -675,13 +796,14 @@ def test_shap_with_predicted_label(
675
796
"probability" : probability ,
676
797
"label_headers" : label_headers ,
677
798
}
678
- _run_test_shap (
799
+ _run_test_explain (
679
800
name_from_base ,
680
801
clarify_processor ,
681
802
clarify_processor_with_job_name_prefix ,
682
803
data_config ,
683
804
model_config ,
684
805
shap_config ,
806
+ pdp_config ,
685
807
model_scores ,
686
808
expected_predictor_config ,
687
809
)
0 commit comments