Skip to content

Commit 6361ad6

Browse files
Merge branch 'master' into feature/CompilationStep
2 parents de1719e + 9be4c8a commit 6361ad6

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

src/sagemaker/clarify.py

+7
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def __init__(
3838
dataset_type="text/csv",
3939
s3_data_distribution_type="FullyReplicated",
4040
s3_compression_type="None",
41+
joinsource=None,
4142
):
4243
"""Initializes a configuration of both input and output datasets.
4344
@@ -57,6 +58,11 @@ def __init__(
5758
s3_data_distribution_type (str): Valid options are "FullyReplicated" or
5859
"ShardedByS3Key".
5960
s3_compression_type (str): Valid options are "None" or "Gzip".
61+
joinsource (str): The name or index of the column in the dataset that acts as an
62+
identifier column (for instance, while performing a join). This column is only
63+
used as an identifier, and not used for any other computations. This is an
64+
optional field in all cases except when the dataset contains more than one file,
65+
and `save_local_shap_values` is set to true in SHAPConfig.
6066
"""
6167
if dataset_type not in ["text/csv", "application/jsonlines", "application/x-parquet"]:
6268
raise ValueError(
@@ -77,6 +83,7 @@ def __init__(
7783
_set(features, "features", self.analysis_config)
7884
_set(headers, "headers", self.analysis_config)
7985
_set(label, "label", self.analysis_config)
86+
_set(joinsource, "joinsource_name_or_index", self.analysis_config)
8087

8188
def get_config(self):
8289
"""Returns part of an analysis config dictionary."""

tests/unit/test_clarify.py

+23-6
Original file line numberDiff line numberDiff line change
@@ -379,13 +379,9 @@ def data_config():
379379
s3_data_input_path="s3://input/train.csv",
380380
s3_output_path="s3://output/analysis_test_result",
381381
label="Label",
382-
headers=[
383-
"Label",
384-
"F1",
385-
"F2",
386-
"F3",
387-
],
382+
headers=["Label", "F1", "F2", "F3", "F4"],
388383
dataset_type="text/csv",
384+
joinsource="F4",
389385
)
390386

391387

@@ -455,7 +451,9 @@ def test_pre_training_bias(
455451
"F1",
456452
"F2",
457453
"F3",
454+
"F4",
458455
],
456+
"joinsource_name_or_index": "F4",
459457
"label": "Label",
460458
"label_values_or_threshold": [1],
461459
"facet": [{"name_or_index": "F1"}],
@@ -516,9 +514,11 @@ def test_post_training_bias(
516514
"F1",
517515
"F2",
518516
"F3",
517+
"F4",
519518
],
520519
"label": "Label",
521520
"label_values_or_threshold": [1],
521+
"joinsource_name_or_index": "F4",
522522
"facet": [{"name_or_index": "F1"}],
523523
"group_variable": "F2",
524524
"methods": {"post_training_bias": {"methods": "all"}},
@@ -646,8 +646,25 @@ def _run_test_explain(
646646
"F1",
647647
"F2",
648648
"F3",
649+
"F4",
649650
],
650651
"label": "Label",
652+
"joinsource_name_or_index": "F4",
653+
"methods": {
654+
"shap": {
655+
"baseline": [
656+
[
657+
0.26124998927116394,
658+
0.2824999988079071,
659+
0.06875000149011612,
660+
]
661+
],
662+
"num_samples": 100,
663+
"agg_method": "mean_sq",
664+
"use_logit": False,
665+
"save_local_shap_values": True,
666+
}
667+
},
651668
"predictor": expected_predictor_config,
652669
}
653670
expected_explanation_configs = {}

0 commit comments

Comments
 (0)