20
20
import os
21
21
import tempfile
22
22
import re
23
+
23
24
from sagemaker .processing import ProcessingInput , ProcessingOutput , Processor
24
25
from sagemaker import image_uris , s3 , utils
25
26
@@ -292,7 +293,30 @@ class ExplainabilityConfig(ABC):
292
293
@abstractmethod
293
294
def get_explainability_config (self ):
294
295
"""Returns config."""
295
- return None
296
+
297
+
298
class PDPConfig(ExplainabilityConfig):
    """Config class for Partial Dependence Plots (PDP)."""

    def __init__(self, features=None, grid_resolution=None):
        """Initializes config for PDP.

        Only arguments that are not ``None`` are recorded in ``self.pdp_config``;
        omitted arguments leave no key behind.

        Args:
            features (None or list): List of feature names or indices for which partial
                dependence plots must be computed and plotted.
            grid_resolution (int): In case of numerical features, this number represents the
                number of buckets that the range of values must be divided into. This decides
                the granularity of the grid in which the PDP are plotted.
        """
        candidate_settings = (
            ("features", features),
            ("grid_resolution", grid_resolution),
        )
        # Keep insertion order (features first) and drop anything the caller left unset.
        self.pdp_config = {
            key: value for key, value in candidate_settings if value is not None
        }

    def get_explainability_config(self):
        """Returns config.

        Returns:
            dict: A new dict with the single key ``"pdp"`` mapped to a deep copy of
            ``self.pdp_config``, so changes to the returned value do not affect this object.
        """
        pdp_settings = copy.deepcopy(self.pdp_config)
        return {"pdp": pdp_settings}
296
320
297
321
298
322
class SHAPConfig (ExplainabilityConfig ):
@@ -466,7 +490,10 @@ def _run(
466
490
will be unassociated.
467
491
* `TrialComponentDisplayName` is used for display in Studio.
468
492
"""
469
- analysis_config ["methods" ]["report" ] = {"name" : "report" , "title" : "Analysis Report" }
493
+ analysis_config ["methods" ]["report" ] = {
494
+ "name" : "report" ,
495
+ "title" : "Analysis Report" ,
496
+ }
470
497
with tempfile .TemporaryDirectory () as tmpdirname :
471
498
analysis_config_file = os .path .join (tmpdirname , "analysis_config.json" )
472
499
with open (analysis_config_file , "w" ) as f :
@@ -568,7 +595,15 @@ def run_pre_training_bias(
568
595
job_name = utils .name_from_base (self .job_name_prefix )
569
596
else :
570
597
job_name = utils .name_from_base ("Clarify-Pretraining-Bias" )
571
- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
598
+ self ._run (
599
+ data_config ,
600
+ analysis_config ,
601
+ wait ,
602
+ logs ,
603
+ job_name ,
604
+ kms_key ,
605
+ experiment_config ,
606
+ )
572
607
573
608
def run_post_training_bias (
574
609
self ,
@@ -646,7 +681,15 @@ def run_post_training_bias(
646
681
job_name = utils .name_from_base (self .job_name_prefix )
647
682
else :
648
683
job_name = utils .name_from_base ("Clarify-Posttraining-Bias" )
649
- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
684
+ self ._run (
685
+ data_config ,
686
+ analysis_config ,
687
+ wait ,
688
+ logs ,
689
+ job_name ,
690
+ kms_key ,
691
+ experiment_config ,
692
+ )
650
693
651
694
def run_bias (
652
695
self ,
@@ -741,7 +784,15 @@ def run_bias(
741
784
job_name = utils .name_from_base (self .job_name_prefix )
742
785
else :
743
786
job_name = utils .name_from_base ("Clarify-Bias" )
744
- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
787
+ self ._run (
788
+ data_config ,
789
+ analysis_config ,
790
+ wait ,
791
+ logs ,
792
+ job_name ,
793
+ kms_key ,
794
+ experiment_config ,
795
+ )
745
796
746
797
def run_explainability (
747
798
self ,
@@ -771,8 +822,9 @@ def run_explainability(
771
822
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
772
823
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
773
824
endpoint to be created.
774
- explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig`): Config of the
775
- specific explainability method. Currently, only SHAP is supported.
825
+ explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list):
826
+ Config of the specific explainability method or a list of ExplainabilityConfig
827
+ objects. Currently, SHAP and PDP are the two methods supported.
776
828
model_scores(str|int|ModelPredictedLabelConfig): Index or JSONPath location in the
777
829
model output for the predicted scores to be explained. This is not required if the
778
830
model output is a single score. Alternatively, an instance of
@@ -781,7 +833,7 @@ def run_explainability(
781
833
logs (bool): Whether to show the logs produced by the job.
782
834
Only meaningful when ``wait`` is True (default: True).
783
835
job_name (str): Processing job name. When ``job_name`` is not specified, if
784
- `` job_name_prefix` ` in :class:`SageMakerClarifyProcessor` specified, the job name
836
+ ``job_name_prefix`` in :class:`SageMakerClarifyProcessor` specified, the job name
785
837
will be composed of ``job_name_prefix`` and current timestamp; otherwise use
786
838
"Clarify-Explainability" as prefix.
787
839
kms_key (str): The ARN of the KMS key that is used to encrypt the
@@ -801,19 +853,41 @@ def run_explainability(
801
853
analysis_config = data_config .get_config ()
802
854
predictor_config = model_config .get_predictor_config ()
803
855
if isinstance (model_scores , ModelPredictedLabelConfig ):
804
- probability_threshold , predicted_label_config = model_scores .get_predictor_config ()
856
+ (
857
+ probability_threshold ,
858
+ predicted_label_config ,
859
+ ) = model_scores .get_predictor_config ()
805
860
_set (probability_threshold , "probability_threshold" , analysis_config )
806
861
predictor_config .update (predicted_label_config )
807
862
else :
808
863
_set (model_scores , "label" , predictor_config )
809
- analysis_config ["methods" ] = explainability_config .get_explainability_config ()
864
+
865
+ explainability_methods = {}
866
+ if isinstance (explainability_config , list ):
867
+ assert (
868
+ len (explainability_config ) > 0
869
+ ), "Invalid Input: Expected list of ExplainabilityConfig but found empty list instead"
870
+ for config in explainability_config :
871
+ explain_config = config .get_explainability_config ()
872
+ explainability_methods .update (explain_config )
873
+ else :
874
+ explainability_methods = explainability_config .get_explainability_config ()
875
+ analysis_config ["methods" ] = explainability_methods
810
876
analysis_config ["predictor" ] = predictor_config
811
877
if job_name is None :
812
878
if self .job_name_prefix :
813
879
job_name = utils .name_from_base (self .job_name_prefix )
814
880
else :
815
881
job_name = utils .name_from_base ("Clarify-Explainability" )
816
- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
882
+ self ._run (
883
+ data_config ,
884
+ analysis_config ,
885
+ wait ,
886
+ logs ,
887
+ job_name ,
888
+ kms_key ,
889
+ experiment_config ,
890
+ )
817
891
818
892
819
893
def _upload_analysis_config (analysis_config_file , s3_output_path , sagemaker_session , kms_key ):
0 commit comments