Skip to content

Commit 87b634c

Browse files
change: Update BiasConfig to accept multiple facet params (aws#2243)
Co-authored-by: Georgios Schinas <[email protected]> Co-authored-by: Ahsan Khan <[email protected]>
1 parent 3bce78f commit 87b634c

File tree

2 files changed

+66
-5
lines changed

2 files changed

+66
-5
lines changed

src/sagemaker/clarify.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -88,21 +88,34 @@ def __init__(
8888
Args:
8989
label_values_or_threshold (Any): List of label values or threshold to indicate positive
9090
outcome used for bias metrics.
91-
facet_name (str): Sensitive attribute in the input data for which we like to compare
92-
metrics.
91+
facet_name (str or [str]): String or List of strings of sensitive attribute(s) in the
92+
input data for which we like to compare metrics.
9393
facet_values_or_threshold (list): Optional list of values to form a sensitive group or
9494
threshold for a numeric facet column that defines the lower bound of a sensitive
9595
group. Defaults to considering each possible value as sensitive group and
9696
computing metrics vs all the other examples.
97+
If facet_name is a list, this needs to be None or a List consisting of lists or None
98+
with the same length as facet_name list.
9799
group_name (str): Optional column name or index to indicate a group column to be used
98100
for the bias metric 'Conditional Demographic Disparity in Labels - CDDL' or
99101
'Conditional Demographic Disparity in Predicted Labels - CDDPL'.
100102
"""
101-
facet = {"name_or_index": facet_name}
102-
_set(facet_values_or_threshold, "value_or_threshold", facet)
103+
if isinstance(facet_name, str):
104+
facet = {"name_or_index": facet_name}
105+
_set(facet_values_or_threshold, "value_or_threshold", facet)
106+
facet_list = [facet]
107+
elif facet_values_or_threshold is None or len(facet_name) == len(facet_values_or_threshold):
108+
facet_list = []
109+
for i, single_facet_name in enumerate(facet_name):
110+
facet = {"name_or_index": single_facet_name}
111+
if facet_values_or_threshold is not None:
112+
_set(facet_values_or_threshold[i], "value_or_threshold", facet)
113+
facet_list.append(facet)
114+
else:
115+
raise ValueError("Wrong combination of argument values passed")
103116
self.analysis_config = {
104117
"label_values_or_threshold": label_values_or_threshold,
105-
"facet": [facet],
118+
"facet": facet_list,
106119
}
107120
_set(group_name, "group_variable", self.analysis_config)
108121

tests/unit/test_clarify.py

+48
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,54 @@ def test_data_bias_config():
8989
assert expected_config == data_bias_config.get_config()
9090

9191

92+
def test_data_bias_config_multi_facet():
93+
label_values = [1]
94+
facet_name = ["Facet1", "Facet2"]
95+
facet_threshold = [[0], [1, 2]]
96+
group_name = "A151"
97+
98+
data_bias_config = BiasConfig(
99+
label_values_or_threshold=label_values,
100+
facet_name=facet_name,
101+
facet_values_or_threshold=facet_threshold,
102+
group_name=group_name,
103+
)
104+
105+
expected_config = {
106+
"label_values_or_threshold": label_values,
107+
"facet": [
108+
{"name_or_index": facet_name[0], "value_or_threshold": facet_threshold[0]},
109+
{"name_or_index": facet_name[1], "value_or_threshold": facet_threshold[1]},
110+
],
111+
"group_variable": group_name,
112+
}
113+
assert expected_config == data_bias_config.get_config()
114+
115+
116+
def test_data_bias_config_multi_facet_not_all_with_value():
117+
label_values = [1]
118+
facet_name = ["Facet1", "Facet2"]
119+
facet_threshold = [[0], None]
120+
group_name = "A151"
121+
122+
data_bias_config = BiasConfig(
123+
label_values_or_threshold=label_values,
124+
facet_name=facet_name,
125+
facet_values_or_threshold=facet_threshold,
126+
group_name=group_name,
127+
)
128+
129+
expected_config = {
130+
"label_values_or_threshold": label_values,
131+
"facet": [
132+
{"name_or_index": facet_name[0], "value_or_threshold": facet_threshold[0]},
133+
{"name_or_index": facet_name[1]},
134+
],
135+
"group_variable": group_name,
136+
}
137+
assert expected_config == data_bias_config.get_config()
138+
139+
92140
def test_model_config():
93141
model_name = "xgboost-model"
94142
instance_type = "ml.c5.xlarge"

0 commit comments

Comments
 (0)