@@ -31,6 +31,7 @@ def __init__(
31
31
self ,
32
32
s3_data_input_path ,
33
33
s3_output_path ,
34
+ s3_analysis_config_output_path = None ,
34
35
label = None ,
35
36
headers = None ,
36
37
features = None ,
@@ -43,6 +44,9 @@ def __init__(
43
44
Args:
44
45
s3_data_input_path (str): Dataset S3 prefix/object URI.
45
46
s3_output_path (str): S3 prefix to store the output.
47
+ s3_analysis_config_output_path (str): S3 prefix to store the analysis_config output
48
+ If this field is None, then the s3_output_path will be used
49
+ to store the analysis_config output
46
50
label (str): Target attribute of the model required by bias metrics (optional for SHAP)
47
51
Specified as column name or index for CSV dataset, or as JSONPath for JSONLines.
48
52
headers (list[str]): A list of column names in the input dataset.
@@ -61,6 +65,7 @@ def __init__(
61
65
)
62
66
self .s3_data_input_path = s3_data_input_path
63
67
self .s3_output_path = s3_output_path
68
+ self .s3_analysis_config_output_path = s3_analysis_config_output_path
64
69
self .s3_data_distribution_type = s3_data_distribution_type
65
70
self .s3_compression_type = s3_compression_type
66
71
self .label = label
@@ -300,12 +305,13 @@ class SHAPConfig(ExplainabilityConfig):
300
305
301
306
def __init__ (
302
307
self ,
303
- baseline ,
304
- num_samples ,
305
- agg_method ,
308
+ baseline = None ,
309
+ num_samples = None ,
310
+ agg_method = None ,
306
311
use_logit = False ,
307
312
save_local_shap_values = True ,
308
313
seed = None ,
314
+ num_clusters = None ,
309
315
):
310
316
"""Initializes config for SHAP.
311
317
@@ -315,34 +321,49 @@ def __init__(
315
321
be the same as the dataset format. Each row should contain only the feature
316
322
columns/values and omit the label column/values. If None a baseline will be
317
323
calculated automatically by using K-means or K-prototypes in the input dataset.
318
- num_samples (int): Number of samples to be used in the Kernel SHAP algorithm.
324
+ num_samples (None or int): Number of samples to be used in the Kernel SHAP algorithm.
319
325
This number determines the size of the generated synthetic dataset to compute the
320
- SHAP values.
321
- agg_method (str): Aggregation method for global SHAP values. Valid values are
326
+ SHAP values. If not provided then Clarify job will choose a proper value according
327
+ to the count of features.
328
+ agg_method (None or str): Aggregation method for global SHAP values. Valid values are
322
329
"mean_abs" (mean of absolute SHAP values for all instances),
323
330
"median" (median of SHAP values for all instances) and
324
331
"mean_sq" (mean of squared SHAP values for all instances).
332
+ If not provided then Clarify job uses method "mean_abs"
325
333
use_logit (bool): Indicator of whether the logit function is to be applied to the model
326
334
predictions. Default is False. If "use_logit" is true then the SHAP values will
327
335
have log-odds units.
328
336
save_local_shap_values (bool): Indicator of whether to save the local SHAP values
329
337
in the output location. Default is True.
330
338
seed (int): seed value to get deterministic SHAP values. Default is None.
339
+ num_clusters (None or int): If a baseline is not provided, Clarify automatically
340
+ computes a baseline dataset via a clustering algorithm (K-means/K-prototypes).
341
+ num_clusters is a parameter for this algorithm. num_clusters will be the resulting
342
+ size of the baseline dataset. If not provided, Clarify job will use a default value.
331
343
"""
332
- if agg_method not in ["mean_abs" , "median" , "mean_sq" ]:
344
+ if agg_method is not None and agg_method not in ["mean_abs" , "median" , "mean_sq" ]:
333
345
raise ValueError (
334
346
f"Invalid agg_method { agg_method } ." f" Please choose mean_abs, median, or mean_sq."
335
347
)
336
-
348
+ if num_clusters is not None and baseline is not None :
349
+ raise ValueError (
350
+ "Baseline and num_clusters cannot be provided together. "
351
+ "Please specify one of the two."
352
+ )
337
353
self .shap_config = {
338
- "baseline" : baseline ,
339
- "num_samples" : num_samples ,
340
- "agg_method" : agg_method ,
341
354
"use_logit" : use_logit ,
342
355
"save_local_shap_values" : save_local_shap_values ,
343
356
}
357
+ if baseline is not None :
358
+ self .shap_config ["baseline" ] = baseline
359
+ if num_samples is not None :
360
+ self .shap_config ["num_samples" ] = num_samples
361
+ if agg_method is not None :
362
+ self .shap_config ["agg_method" ] = agg_method
344
363
if seed is not None :
345
364
self .shap_config ["seed" ] = seed
365
+ if num_clusters is not None :
366
+ self .shap_config ["num_clusters" ] = num_clusters
346
367
347
368
def get_explainability_config (self ):
348
369
"""Returns config."""
@@ -473,7 +494,7 @@ def _run(
473
494
json .dump (analysis_config , f )
474
495
s3_analysis_config_file = _upload_analysis_config (
475
496
analysis_config_file ,
476
- data_config .s3_output_path ,
497
+ data_config .s3_analysis_config_output_path or data_config . s3_output_path ,
477
498
self .sagemaker_session ,
478
499
kms_key ,
479
500
)
0 commit comments