diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index ab6fed6d80..5d6f62b8d9 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -38,6 +38,7 @@ def __init__( dataset_type="text/csv", s3_data_distribution_type="FullyReplicated", s3_compression_type="None", + joinsource=None, ): """Initializes a configuration of both input and output datasets. @@ -57,6 +58,11 @@ def __init__( s3_data_distribution_type (str): Valid options are "FullyReplicated" or "ShardedByS3Key". s3_compression_type (str): Valid options are "None" or "Gzip". + joinsource (str): The name or index of the column in the dataset that acts as an + identifier column (for instance, while performing a join). This column is only + used as an identifier, and not used for any other computations. This is an + optional field in all cases except when the dataset contains more than one file, + and `save_local_shap_values` is set to true in SHAPConfig. """ if dataset_type not in ["text/csv", "application/jsonlines", "application/x-parquet"]: raise ValueError( @@ -77,6 +83,7 @@ def __init__( _set(features, "features", self.analysis_config) _set(headers, "headers", self.analysis_config) _set(label, "label", self.analysis_config) + _set(joinsource, "joinsource_name_or_index", self.analysis_config) def get_config(self): """Returns part of an analysis config dictionary.""" diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index 0b2bf1b2ec..251f352a87 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -379,13 +379,9 @@ def data_config(): s3_data_input_path="s3://input/train.csv", s3_output_path="s3://output/analysis_test_result", label="Label", - headers=[ - "Label", - "F1", - "F2", - "F3", - ], + headers=["Label", "F1", "F2", "F3", "F4"], dataset_type="text/csv", + joinsource="F4", ) @@ -455,7 +451,9 @@ def test_pre_training_bias( "F1", "F2", "F3", + "F4", ], + "joinsource_name_or_index": "F4", "label": "Label", "label_values_or_threshold": 
[1], "facet": [{"name_or_index": "F1"}], @@ -516,9 +514,11 @@ def test_post_training_bias( "F1", "F2", "F3", + "F4", ], "label": "Label", "label_values_or_threshold": [1], + "joinsource_name_or_index": "F4", "facet": [{"name_or_index": "F1"}], "group_variable": "F2", "methods": {"post_training_bias": {"methods": "all"}}, @@ -646,8 +646,25 @@ def _run_test_explain( "F1", "F2", "F3", + "F4", ], "label": "Label", + "joinsource_name_or_index": "F4", + "methods": { + "shap": { + "baseline": [ + [ + 0.26124998927116394, + 0.2824999988079071, + 0.06875000149011612, + ] + ], + "num_samples": 100, + "agg_method": "mean_sq", + "use_logit": False, + "save_local_shap_values": True, + } + }, "predictor": expected_predictor_config, } expected_explanation_configs = {}