20
20
import os
21
21
import tempfile
22
22
import re
23
+ from typing import List
24
+
23
25
from sagemaker .processing import ProcessingInput , ProcessingOutput , Processor
24
26
from sagemaker import image_uris , s3 , utils
25
27
@@ -54,7 +56,11 @@ def __init__(
54
56
"ShardedByS3Key".
55
57
s3_compression_type (str): Valid options are "None" or "Gzip".
56
58
"""
57
- if dataset_type not in ["text/csv" , "application/jsonlines" , "application/x-parquet" ]:
59
+ if dataset_type not in [
60
+ "text/csv" ,
61
+ "application/jsonlines" ,
62
+ "application/x-parquet" ,
63
+ ]:
58
64
raise ValueError (
59
65
f"Invalid dataset_type '{ dataset_type } '."
60
66
f" Please check the API documentation for the supported dataset types."
@@ -292,7 +298,30 @@ class ExplainabilityConfig(ABC):
292
298
@abstractmethod
293
299
def get_explainability_config (self ):
294
300
"""Returns config."""
295
- return None
301
+
302
+
303
class PDPConfig(ExplainabilityConfig):
    """Explainability configuration for Partial Dependence Plots (PDP)."""

    def __init__(self, features=None, grid_resolution=None):
        """Initializes config for PDP.

        Only arguments that are not ``None`` are recorded, so an omitted option is
        absent from the stored config instead of being stored as ``None``.

        Args:
            features (None or list): List of feature names or indices for which partial
                dependence plots must be computed and plotted.
            grid_resolution (int): In case of numerical features, the number of buckets
                that the range of values must be divided into. This decides the
                granularity of the grid in which the PDP are plotted.
        """
        # Pairs are listed in the order the keys should appear in the config.
        supplied = (("features", features), ("grid_resolution", grid_resolution))
        self.pdp_config = {key: value for key, value in supplied if value is not None}

    def get_explainability_config(self):
        """Returns the PDP options as a deep copy nested under the ``"pdp"`` key."""
        # Deep copy so that callers mutating the result cannot alter this object.
        snapshot = copy.deepcopy(self.pdp_config)
        return {"pdp": snapshot}
296
325
297
326
298
327
class SHAPConfig (ExplainabilityConfig ):
@@ -466,7 +495,10 @@ def _run(
466
495
will be unassociated.
467
496
* `TrialComponentDisplayName` is used for display in Studio.
468
497
"""
469
- analysis_config ["methods" ]["report" ] = {"name" : "report" , "title" : "Analysis Report" }
498
+ analysis_config ["methods" ]["report" ] = {
499
+ "name" : "report" ,
500
+ "title" : "Analysis Report" ,
501
+ }
470
502
with tempfile .TemporaryDirectory () as tmpdirname :
471
503
analysis_config_file = os .path .join (tmpdirname , "analysis_config.json" )
472
504
with open (analysis_config_file , "w" ) as f :
@@ -568,7 +600,15 @@ def run_pre_training_bias(
568
600
job_name = utils .name_from_base (self .job_name_prefix )
569
601
else :
570
602
job_name = utils .name_from_base ("Clarify-Pretraining-Bias" )
571
- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
603
+ self ._run (
604
+ data_config ,
605
+ analysis_config ,
606
+ wait ,
607
+ logs ,
608
+ job_name ,
609
+ kms_key ,
610
+ experiment_config ,
611
+ )
572
612
573
613
def run_post_training_bias (
574
614
self ,
@@ -646,7 +686,15 @@ def run_post_training_bias(
646
686
job_name = utils .name_from_base (self .job_name_prefix )
647
687
else :
648
688
job_name = utils .name_from_base ("Clarify-Posttraining-Bias" )
649
- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
689
+ self ._run (
690
+ data_config ,
691
+ analysis_config ,
692
+ wait ,
693
+ logs ,
694
+ job_name ,
695
+ kms_key ,
696
+ experiment_config ,
697
+ )
650
698
651
699
def run_bias (
652
700
self ,
@@ -741,7 +789,15 @@ def run_bias(
741
789
job_name = utils .name_from_base (self .job_name_prefix )
742
790
else :
743
791
job_name = utils .name_from_base ("Clarify-Bias" )
744
- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
792
+ self ._run (
793
+ data_config ,
794
+ analysis_config ,
795
+ wait ,
796
+ logs ,
797
+ job_name ,
798
+ kms_key ,
799
+ experiment_config ,
800
+ )
745
801
746
802
def run_explainability (
747
803
self ,
@@ -772,7 +828,8 @@ def run_explainability(
772
828
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
773
829
endpoint to be created.
774
830
explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig`): Config of the
775
- specific explainability method. Currently, only SHAP is supported.
831
+ specific explainability method or a list of ExplainabilityConfig objects. Currently,
832
+ SHAP and PDP are the two methods supported.
776
833
model_scores(str|int|ModelPredictedLabelConfig): Index or JSONPath location in the
777
834
model output for the predicted scores to be explained. This is not required if the
778
835
model output is a single score. Alternatively, an instance of
@@ -781,7 +838,7 @@ def run_explainability(
781
838
logs (bool): Whether to show the logs produced by the job.
782
839
Only meaningful when ``wait`` is True (default: True).
783
840
job_name (str): Processing job name. When ``job_name`` is not specified, if
784
- `` job_name_prefix` ` in :class:`SageMakerClarifyProcessor` specified, the job name
841
+ `job_name_prefix` in :class:`SageMakerClarifyProcessor` specified, the job name
785
842
will be composed of ``job_name_prefix`` and current timestamp; otherwise use
786
843
"Clarify-Explainability" as prefix.
787
844
kms_key (str): The ARN of the KMS key that is used to encrypt the
@@ -801,19 +858,44 @@ def run_explainability(
801
858
analysis_config = data_config .get_config ()
802
859
predictor_config = model_config .get_predictor_config ()
803
860
if isinstance (model_scores , ModelPredictedLabelConfig ):
804
- probability_threshold , predicted_label_config = model_scores .get_predictor_config ()
861
+ (
862
+ probability_threshold ,
863
+ predicted_label_config ,
864
+ ) = model_scores .get_predictor_config ()
805
865
_set (probability_threshold , "probability_threshold" , analysis_config )
806
866
predictor_config .update (predicted_label_config )
807
867
else :
808
868
_set (model_scores , "label" , predictor_config )
809
- analysis_config ["methods" ] = explainability_config .get_explainability_config ()
869
+
870
+ explainability_methods = {}
871
+ if isinstance (explainability_config , List ): # pylint: disable=W1116
872
+ for config in explainability_config :
873
+ if not isinstance (config , ExplainabilityConfig ):
874
+ raise ValueError (
875
+ f"Invalid input: Excepted ExplainabilityConfig, got { type (config )} instead"
876
+ )
877
+ explain_config = config .get_explainability_config ()
878
+ explainability_methods [list (explain_config .keys ())[0 ]] = explain_config [
879
+ list (explain_config .keys ())[0 ]
880
+ ]
881
+ elif isinstance (explainability_config , ExplainabilityConfig ):
882
+ explainability_methods = explainability_config .get_explainability_config ()
883
+ analysis_config ["methods" ] = explainability_methods
810
884
analysis_config ["predictor" ] = predictor_config
811
885
if job_name is None :
812
886
if self .job_name_prefix :
813
887
job_name = utils .name_from_base (self .job_name_prefix )
814
888
else :
815
889
job_name = utils .name_from_base ("Clarify-Explainability" )
816
- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
890
+ self ._run (
891
+ data_config ,
892
+ analysis_config ,
893
+ wait ,
894
+ logs ,
895
+ job_name ,
896
+ kms_key ,
897
+ experiment_config ,
898
+ )
817
899
818
900
819
901
def _upload_analysis_config (analysis_config_file , s3_output_path , sagemaker_session , kms_key ):
0 commit comments