20
20
import os
21
21
import tempfile
22
22
import re
23
+
23
24
from sagemaker .processing import ProcessingInput , ProcessingOutput , Processor
24
25
from sagemaker import image_uris , s3 , utils
25
26
@@ -297,7 +298,30 @@ class ExplainabilityConfig(ABC):
297
298
@abstractmethod
298
299
def get_explainability_config (self ):
299
300
"""Returns config."""
300
- return None
301
+
302
+
303
+ class PDPConfig (ExplainabilityConfig ):
304
+ """Config class for Partial Dependence Plots (PDP)"""
305
+
306
+ def __init__ (self , features = None , grid_resolution = None ):
307
+ """Initializes config for PDP.
308
+
309
+ Args:
310
+ features (None or list): List of features names or indices for which partial dependence
311
+ plots must be computed and plotted.
312
+ grid_resolution (int): In case of numerical features, this number represents that
313
+ number of buckets that range of values must be divided into. This decides the
314
+ granularity of the grid in which the PDP are plotted.
315
+ """
316
+ self .pdp_config = {}
317
+ if features is not None :
318
+ self .pdp_config ["features" ] = features
319
+ if grid_resolution is not None :
320
+ self .pdp_config ["grid_resolution" ] = grid_resolution
321
+
322
+ def get_explainability_config (self ):
323
+ """Returns config."""
324
+ return copy .deepcopy ({"pdp" : self .pdp_config })
301
325
302
326
303
327
class SHAPConfig (ExplainabilityConfig ):
@@ -471,7 +495,10 @@ def _run(
471
495
will be unassociated.
472
496
* `TrialComponentDisplayName` is used for display in Studio.
473
497
"""
474
- analysis_config ["methods" ]["report" ] = {"name" : "report" , "title" : "Analysis Report" }
498
+ analysis_config ["methods" ]["report" ] = {
499
+ "name" : "report" ,
500
+ "title" : "Analysis Report" ,
501
+ }
475
502
with tempfile .TemporaryDirectory () as tmpdirname :
476
503
analysis_config_file = os .path .join (tmpdirname , "analysis_config.json" )
477
504
with open (analysis_config_file , "w" ) as f :
@@ -573,7 +600,15 @@ def run_pre_training_bias(
573
600
job_name = utils .name_from_base (self .job_name_prefix )
574
601
else :
575
602
job_name = utils .name_from_base ("Clarify-Pretraining-Bias" )
576
- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
603
+ self ._run (
604
+ data_config ,
605
+ analysis_config ,
606
+ wait ,
607
+ logs ,
608
+ job_name ,
609
+ kms_key ,
610
+ experiment_config ,
611
+ )
577
612
578
613
def run_post_training_bias (
579
614
self ,
@@ -651,7 +686,15 @@ def run_post_training_bias(
651
686
job_name = utils .name_from_base (self .job_name_prefix )
652
687
else :
653
688
job_name = utils .name_from_base ("Clarify-Posttraining-Bias" )
654
- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
689
+ self ._run (
690
+ data_config ,
691
+ analysis_config ,
692
+ wait ,
693
+ logs ,
694
+ job_name ,
695
+ kms_key ,
696
+ experiment_config ,
697
+ )
655
698
656
699
def run_bias (
657
700
self ,
@@ -746,7 +789,15 @@ def run_bias(
746
789
job_name = utils .name_from_base (self .job_name_prefix )
747
790
else :
748
791
job_name = utils .name_from_base ("Clarify-Bias" )
749
- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
792
+ self ._run (
793
+ data_config ,
794
+ analysis_config ,
795
+ wait ,
796
+ logs ,
797
+ job_name ,
798
+ kms_key ,
799
+ experiment_config ,
800
+ )
750
801
751
802
def run_explainability (
752
803
self ,
@@ -776,8 +827,9 @@ def run_explainability(
776
827
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
777
828
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
778
829
endpoint to be created.
779
- explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig`): Config of the
780
- specific explainability method. Currently, only SHAP is supported.
830
+ explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list):
831
+ Config of the specific explainability method or a list of ExplainabilityConfig
832
+ objects. Currently, SHAP and PDP are the two methods supported.
781
833
model_scores(str|int|ModelPredictedLabelConfig): Index or JSONPath location in the
782
834
model output for the predicted scores to be explained. This is not required if the
783
835
model output is a single score. Alternatively, an instance of
@@ -786,7 +838,7 @@ def run_explainability(
786
838
logs (bool): Whether to show the logs produced by the job.
787
839
Only meaningful when ``wait`` is True (default: True).
788
840
job_name (str): Processing job name. When ``job_name`` is not specified, if
789
- `` job_name_prefix` ` in :class:`SageMakerClarifyProcessor` specified, the job name
841
+ `job_name_prefix` in :class:`SageMakerClarifyProcessor` specified, the job name
790
842
will be composed of ``job_name_prefix`` and current timestamp; otherwise use
791
843
"Clarify-Explainability" as prefix.
792
844
kms_key (str): The ARN of the KMS key that is used to encrypt the
@@ -806,19 +858,44 @@ def run_explainability(
806
858
analysis_config = data_config .get_config ()
807
859
predictor_config = model_config .get_predictor_config ()
808
860
if isinstance (model_scores , ModelPredictedLabelConfig ):
809
- probability_threshold , predicted_label_config = model_scores .get_predictor_config ()
861
+ (
862
+ probability_threshold ,
863
+ predicted_label_config ,
864
+ ) = model_scores .get_predictor_config ()
810
865
_set (probability_threshold , "probability_threshold" , analysis_config )
811
866
predictor_config .update (predicted_label_config )
812
867
else :
813
868
_set (model_scores , "label" , predictor_config )
814
- analysis_config ["methods" ] = explainability_config .get_explainability_config ()
869
+
870
+ explainability_methods = {}
871
+ if isinstance (explainability_config , list ):
872
+ assert (
873
+ len (explainability_config ) > 0
874
+ ), "Please provide at least one explaianbility config."
875
+ for config in explainability_config :
876
+ explain_config = config .get_explainability_config ()
877
+ explainability_methods .update (explain_config )
878
+ assert len (explainability_methods .keys ()) == len (
879
+ explainability_config
880
+ ), "There are duplicate explainability configs"
881
+ else :
882
+ explainability_methods = explainability_config .get_explainability_config ()
883
+ analysis_config ["methods" ] = explainability_methods
815
884
analysis_config ["predictor" ] = predictor_config
816
885
if job_name is None :
817
886
if self .job_name_prefix :
818
887
job_name = utils .name_from_base (self .job_name_prefix )
819
888
else :
820
889
job_name = utils .name_from_base ("Clarify-Explainability" )
821
- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
890
+ self ._run (
891
+ data_config ,
892
+ analysis_config ,
893
+ wait ,
894
+ logs ,
895
+ job_name ,
896
+ kms_key ,
897
+ experiment_config ,
898
+ )
822
899
823
900
824
901
def _upload_analysis_config (analysis_config_file , s3_output_path , sagemaker_session , kms_key ):
0 commit comments