Skip to content

Commit 3d703c4

Browse files
feature: add joinsource to DataConfig
1 parent b082eb4 commit 3d703c4

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

src/sagemaker/clarify.py

+7
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737
dataset_type="text/csv",
3838
s3_data_distribution_type="FullyReplicated",
3939
s3_compression_type="None",
40+
joinsource=None,
4041
):
4142
"""Initializes a configuration of both input and output datasets.
4243
@@ -53,6 +54,11 @@ def __init__(
5354
s3_data_distribution_type (str): Valid options are "FullyReplicated" or
5455
"ShardedByS3Key".
5556
s3_compression_type (str): Valid options are "None" or "Gzip".
57+
joinsource (str): the name or index of the column in the dataset that acts as an
58+
identifier column (for instance, while performing a join). This column is only
59+
used as an identifier, and not used for any other computations. This is an
60+
optional field in all cases except when the dataset contains more than one file,
61+
and `save_local_shap_values` is set to true in SHAPConfig.
5662
"""
5763
if dataset_type not in ["text/csv", "application/jsonlines", "application/x-parquet"]:
5864
raise ValueError(
@@ -72,6 +78,7 @@ def __init__(
7278
_set(features, "features", self.analysis_config)
7379
_set(headers, "headers", self.analysis_config)
7480
_set(label, "label", self.analysis_config)
81+
_set(joinsource, "joinsource_name_or_index", self.analysis_config)
7582

7683
def get_config(self):
7784
"""Returns part of an analysis config dictionary."""

tests/unit/test_clarify.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -326,13 +326,9 @@ def data_config():
326326
s3_data_input_path="s3://input/train.csv",
327327
s3_output_path="s3://output/analysis_test_result",
328328
label="Label",
329-
headers=[
330-
"Label",
331-
"F1",
332-
"F2",
333-
"F3",
334-
],
329+
headers=["Label", "F1", "F2", "F3", "F4"],
335330
dataset_type="text/csv",
331+
joinsource="F4",
336332
)
337333

338334

@@ -397,7 +393,9 @@ def test_pre_training_bias(
397393
"F1",
398394
"F2",
399395
"F3",
396+
"F4",
400397
],
398+
"joinsource_name_or_index": "F4",
401399
"label": "Label",
402400
"label_values_or_threshold": [1],
403401
"facet": [{"name_or_index": "F1"}],
@@ -458,9 +456,11 @@ def test_post_training_bias(
458456
"F1",
459457
"F2",
460458
"F3",
459+
"F4",
461460
],
462461
"label": "Label",
463462
"label_values_or_threshold": [1],
463+
"joinsource_name_or_index": "F4",
464464
"facet": [{"name_or_index": "F1"}],
465465
"group_variable": "F2",
466466
"methods": {"post_training_bias": {"methods": "all"}},
@@ -526,8 +526,10 @@ def _run_test_shap(
526526
"F1",
527527
"F2",
528528
"F3",
529+
"F4",
529530
],
530531
"label": "Label",
532+
"joinsource_name_or_index": "F4",
531533
"methods": {
532534
"shap": {
533535
"baseline": [

0 commit comments

Comments
 (0)