@@ -37,6 +37,7 @@ def __init__(
37
37
dataset_type = "text/csv" ,
38
38
s3_data_distribution_type = "FullyReplicated" ,
39
39
s3_compression_type = "None" ,
40
+ joinsource = None ,
40
41
):
41
42
"""Initializes a configuration of both input and output datasets.
42
43
@@ -53,6 +54,11 @@ def __init__(
53
54
s3_data_distribution_type (str): Valid options are "FullyReplicated" or
54
55
"ShardedByS3Key".
55
56
s3_compression_type (str): Valid options are "None" or "Gzip".
57
+ joinsource (str): the name or index of the column in the dataset that acts as an
58
+ identifier column (for instance, while performing a join). This column is only
59
+ used as an identifier, and not used for any other computations. This is an
60
+ optional field in all cases except when the dataset contains more than one file,
61
+ and `save_local_shap_values` is set to true in SHAPConfig.
56
62
"""
57
63
if dataset_type not in ["text/csv" , "application/jsonlines" , "application/x-parquet" ]:
58
64
raise ValueError (
@@ -72,6 +78,7 @@ def __init__(
72
78
_set (features , "features" , self .analysis_config )
73
79
_set (headers , "headers" , self .analysis_config )
74
80
_set (label , "label" , self .analysis_config )
81
+ _set (joinsource , "joinsource_name_or_index" , self .analysis_config )
75
82
76
83
def get_config (self ):
77
84
"""Returns part of an analysis config dictionary."""
0 commit comments