@@ -45,6 +45,12 @@ def __init__(
         dataset_type="text/csv",
         s3_compression_type="None",
         joinsource=None,
+        facet_dataset_uri=None,
+        facet_headers=None,
+        predicted_label_dataset_uri=None,
+        predicted_label_headers=None,
+        predicted_label=None,
+        excluded_columns=None,
     ):
         """Initializes a configuration of both input and output datasets.
 
@@ -54,22 +60,57 @@ def __init__(
             s3_analysis_config_output_path (str): S3 prefix to store the analysis config output.
                 If this field is None, then the ``s3_output_path`` will be used
                 to store the ``analysis_config`` output.
-            label (str): Target attribute of the model **required** for bias metrics (both pre-
-                and post-training). Optional when running SHAP explainability.
-                Specified as column name or index for CSV dataset, or as JSONPath for JSONLines.
-            headers (list[str]): A list of column names in the input dataset.
+            label (str): Target attribute of the model required by bias metrics.
+                Specified as column name or index for CSV dataset or as JSONPath for JSONLines.
+                *Required parameter* except for when the input dataset does not contain the label.
+                Cannot be used at the same time as ``predicted_label``.
             features (str): JSONPath for locating the feature columns for bias metrics if the
                 dataset format is JSONLines.
             dataset_type (str): Format of the dataset. Valid values are ``"text/csv"`` for CSV,
                 ``"application/jsonlines"`` for JSONLines, and
                 ``"application/x-parquet"`` for Parquet.
             s3_compression_type (str): Valid options are "None" or ``"Gzip"``.
-            joinsource (str): The name or index of the column in the dataset that acts as an
-                identifier column (for instance, while performing a join). This column is only
-                used as an identifier, and not used for any other computations. This is an
-                optional field in all cases except when the dataset contains more than one file,
-                and ``save_local_shap_values`` is set to True
-                in :class:`~sagemaker.clarify.SHAPConfig`.
+            joinsource (str or int): The name or index of the column in the dataset that
+                acts as an identifier column (for instance, while performing a join).
+                This column is only used as an identifier, and not used for any other computations.
+                This is an optional field in all cases except:
+
+                * The dataset contains more than one file and ``save_local_shap_values``
+                  is set to ``True`` in :class:`~sagemaker.clarify.SHAPConfig`, and/or
+                * When the dataset and/or facet dataset and/or predicted label dataset
+                  are in separate files.
+
+            facet_dataset_uri (str): Dataset S3 prefix/object URI that contains facet attribute(s),
+                used for bias analysis on datasets without facets.
+
+                * If the dataset and the facet dataset are one single file each, then
+                  the original dataset and facet dataset must have the same number of rows.
+                * If the dataset and facet dataset are in multiple files (either one), then
+                  an index column, ``joinsource``, is required to join the two datasets.
+
+                Clarify will not use the ``joinsource`` column and columns present in the facet
+                dataset when calling model inference APIs.
+            facet_headers (list[str]): List of column names in the facet dataset.
+            predicted_label_dataset_uri (str): Dataset S3 prefix/object URI with predicted labels,
+                which are used directly for analysis instead of making model inference API calls.
+
+                * If the dataset and the predicted label dataset are one single file each, then the
+                  original dataset and predicted label dataset must have the same number of rows.
+                * If the dataset and predicted label dataset are in multiple files (either one),
+                  then an index column, ``joinsource``, is required to join the two datasets.
+
+            predicted_label_headers (list[str]): List of column names in the predicted label dataset.
+            predicted_label (str or int): Predicted label of the target attribute of the model
+                required for running bias analysis. Specified as column name or index for CSV data.
+                Clarify uses the predicted labels directly instead of making model inference API
+                calls. Cannot be used at the same time as ``label``.
+            excluded_columns (list[int] or list[str]): A list of names or indices of the columns
+                which are to be excluded from making model inference API calls.
+
+        Raises:
+            ValueError: when the ``dataset_type`` is invalid, predicted label dataset parameters
+                are used with unsupported ``dataset_type``, or facet dataset parameters
+                are used with unsupported ``dataset_type``.
         """
         if dataset_type not in [
             "text/csv",
@@ -81,6 +122,32 @@ def __init__(
                 f"Invalid dataset_type '{dataset_type}'."
                 f" Please check the API documentation for the supported dataset types."
             )
+        # parameters for analysis on datasets without facets are only supported for CSV datasets
+        if dataset_type != "text/csv":
+            if predicted_label:
+                raise ValueError(
+                    f"The parameter 'predicted_label' is not supported"
+                    f" for dataset_type '{dataset_type}'."
+                    f" Please check the API documentation for the supported dataset types."
+                )
+            if excluded_columns:
+                raise ValueError(
+                    f"The parameter 'excluded_columns' is not supported"
+                    f" for dataset_type '{dataset_type}'."
+                    f" Please check the API documentation for the supported dataset types."
+                )
+            if facet_dataset_uri or facet_headers:
+                raise ValueError(
+                    f"The parameters 'facet_dataset_uri' and 'facet_headers'"
+                    f" are not supported for dataset_type '{dataset_type}'."
+                    f" Please check the API documentation for the supported dataset types."
+                )
+            if predicted_label_dataset_uri or predicted_label_headers:
+                raise ValueError(
+                    f"The parameters 'predicted_label_dataset_uri' and 'predicted_label_headers'"
+                    f" are not supported for dataset_type '{dataset_type}'."
+                    f" Please check the API documentation for the supported dataset types."
+                )
         self.s3_data_input_path = s3_data_input_path
         self.s3_output_path = s3_output_path
         self.s3_analysis_config_output_path = s3_analysis_config_output_path
@@ -89,13 +156,25 @@ def __init__(
         self.label = label
         self.headers = headers
         self.features = features
+        self.facet_dataset_uri = facet_dataset_uri
+        self.facet_headers = facet_headers
+        self.predicted_label_dataset_uri = predicted_label_dataset_uri
+        self.predicted_label_headers = predicted_label_headers
+        self.predicted_label = predicted_label
+        self.excluded_columns = excluded_columns
         self.analysis_config = {
             "dataset_type": dataset_type,
         }
         _set(features, "features", self.analysis_config)
         _set(headers, "headers", self.analysis_config)
         _set(label, "label", self.analysis_config)
         _set(joinsource, "joinsource_name_or_index", self.analysis_config)
+        _set(facet_dataset_uri, "facet_dataset_uri", self.analysis_config)
+        _set(facet_headers, "facet_headers", self.analysis_config)
+        _set(predicted_label_dataset_uri, "predicted_label_dataset_uri", self.analysis_config)
+        _set(predicted_label_headers, "predicted_label_headers", self.analysis_config)
+        _set(predicted_label, "predicted_label", self.analysis_config)
+        _set(excluded_columns, "excluded_columns", self.analysis_config)
 
     def get_config(self):
         """Returns part of an analysis config dictionary."""
@@ -205,21 +284,23 @@ def __init__(
         r"""Initializes a configuration of a model and the endpoint to be created for it.
 
         Args:
-            model_name (str): Model name (as created by 'CreateModel').
+            model_name (str): Model name (as created by
+                `CreateModel <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateModel.html>`_).
             instance_count (int): The number of instances of a new endpoint for model inference.
-            instance_type (str): The type of EC2 instance to use for model inference,
-                for example, ``"ml.c5.xlarge"``.
+            instance_type (str): The type of
+                `EC2 instance <https://aws.amazon.com/ec2/instance-types/>`_
+                to use for model inference; for example, ``"ml.c5.xlarge"``.
             accept_type (str): The model output format to be used for getting inferences with the
-                shadow endpoint. Valid values are "text/csv" for CSV and "application/jsonlines".
-                Default is the same as content_type.
+                shadow endpoint. Valid values are ``"text/csv"`` for CSV and
+                ``"application/jsonlines"``. Default is the same as ``content_type``.
             content_type (str): The model input format to be used for getting inferences with the
-                shadow endpoint. Valid values are "text/csv" for CSV and "application/jsonlines".
-                Default is the same as dataset format.
+                shadow endpoint. Valid values are ``"text/csv"`` for CSV and
+                ``"application/jsonlines"``. Default is the same as ``dataset_format``.
             content_template (str): A template string to be used to construct the model input from
                 dataset instances. It is only used when ``model_content_type`` is
                 ``"application/jsonlines"``. The template should have one and only one placeholder,
-                "features", which will be replaced by a features list to form the model inference
-                input.
+                ``"features"``, which will be replaced by a features list to form the model
+                inference input.
             custom_attributes (str): Provides additional information about a request for an
                 inference submitted to a model hosted at an Amazon SageMaker endpoint. The
                 information is an opaque value that is forwarded verbatim. You could use this
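For reference, a ``ModelConfig`` matching this docstring might be built as in the sketch below; the model name is hypothetical, and the accept/content types simply mirror the CSV defaults described above:

from sagemaker.clarify import ModelConfig

model_config = ModelConfig(
    model_name="my-clarify-model",  # hypothetical model created via CreateModel
    instance_count=1,
    instance_type="ml.c5.xlarge",  # EC2 instance type for the shadow endpoint
    accept_type="text/csv",  # model output format
    content_type="text/csv",  # model input format
)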
@@ -509,16 +590,20 @@ def __init__(
                 for these units.
             language (str): Specifies the language of the text features. Accepted values are
                 one of the following:
-                "chinese", "danish", "dutch", "english", "french", "german", "greek", "italian",
-                "japanese", "lithuanian", "multi-language", "norwegian bokmål", "polish",
-                "portuguese", "romanian", "russian", "spanish", "afrikaans", "albanian", "arabic",
-                "armenian", "basque", "bengali", "bulgarian", "catalan", "croatian", "czech",
-                "estonian", "finnish", "gujarati", "hebrew", "hindi", "hungarian", "icelandic",
-                "indonesian", "irish", "kannada", "kyrgyz", "latvian", "ligurian", "luxembourgish",
-                "macedonian", "malayalam", "marathi", "nepali", "persian", "sanskrit", "serbian",
-                "setswana", "sinhala", "slovak", "slovenian", "swedish", "tagalog", "tamil",
-                "tatar", "telugu", "thai", "turkish", "ukrainian", "urdu", "vietnamese", "yoruba".
-                Use "multi-language" for a mix of multiple languages.
+                ``"chinese"``, ``"danish"``, ``"dutch"``, ``"english"``, ``"french"``, ``"german"``,
+                ``"greek"``, ``"italian"``, ``"japanese"``, ``"lithuanian"``, ``"multi-language"``,
+                ``"norwegian bokmål"``, ``"polish"``, ``"portuguese"``, ``"romanian"``,
+                ``"russian"``, ``"spanish"``, ``"afrikaans"``, ``"albanian"``, ``"arabic"``,
+                ``"armenian"``, ``"basque"``, ``"bengali"``, ``"bulgarian"``, ``"catalan"``,
+                ``"croatian"``, ``"czech"``, ``"estonian"``, ``"finnish"``, ``"gujarati"``,
+                ``"hebrew"``, ``"hindi"``, ``"hungarian"``, ``"icelandic"``, ``"indonesian"``,
+                ``"irish"``, ``"kannada"``, ``"kyrgyz"``, ``"latvian"``, ``"ligurian"``,
+                ``"luxembourgish"``, ``"macedonian"``, ``"malayalam"``, ``"marathi"``, ``"nepali"``,
+                ``"persian"``, ``"sanskrit"``, ``"serbian"``, ``"setswana"``, ``"sinhala"``,
+                ``"slovak"``, ``"slovenian"``, ``"swedish"``, ``"tagalog"``, ``"tamil"``,
+                ``"tatar"``, ``"telugu"``, ``"thai"``, ``"turkish"``, ``"ukrainian"``, ``"urdu"``,
+                ``"vietnamese"``, ``"yoruba"``.
+                Use ``"multi-language"`` for a mix of multiple languages.
 
         Raises:
             ValueError: when ``granularity`` is not in list of supported values
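As a usage sketch, not taken from this diff, a ``TextConfig`` for English text explained at token granularity can be attached to a ``SHAPConfig`` through its ``text_config`` parameter; the exact required arguments of ``SHAPConfig`` (such as ``baseline``) vary across SDK versions:

from sagemaker.clarify import SHAPConfig, TextConfig

text_config = TextConfig(
    granularity="token",  # report feature importance per token
    language="english",
)
shap_config = SHAPConfig(
    num_samples=100,  # number of synthetic samples used by Kernel SHAP
    agg_method="mean_abs",  # aggregate local attributions by mean absolute value
    text_config=text_config,
)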
@@ -742,12 +827,15 @@ def __init__(
                 data stored in Amazon S3.
             instance_count (int): The number of instances to run
                 a processing job with.
-            instance_type (str): The type of EC2 instance to use for
-                processing, for example, ``'ml.c4.xlarge'``.
-            volume_size_in_gb (int): Size in GB of the EBS volume
-                to use for storing data during processing (default: 30).
-            volume_kms_key (str): A KMS key for the processing
-                volume (default: None).
+            instance_type (str): The type of
+                `EC2 instance <https://aws.amazon.com/ec2/instance-types/>`_
+                to use for processing; for example, ``"ml.c5.xlarge"``.
+            volume_size_in_gb (int): Size in GB of the
+                `EBS volume <https://docs.aws.amazon.com/sagemaker/latest/dg/host-instance-storage.html>`_
+                to use for storing data during processing (default: 30 GB).
+            volume_kms_key (str): A
+                `KMS key <https://docs.aws.amazon.com/sagemaker/latest/dg/key-management.html>`_
+                for the processing volume (default: None).
             output_kms_key (str): The KMS key ID for processing job outputs (default: None).
             max_runtime_in_seconds (int): Timeout in seconds (default: None).
                 After this amount of time, Amazon SageMaker terminates the job,
@@ -769,7 +857,7 @@ def __init__(
                 inter-container traffic, security group IDs, and subnets.
             job_name_prefix (str): Processing job name prefix.
             version (str): Clarify version to use.
-        """
+        """  # noqa E501  # pylint: disable=c0301
         container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version)
         self.job_name_prefix = job_name_prefix
         super(SageMakerClarifyProcessor, self).__init__(
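Construction of the processor itself is unchanged by this commit; a minimal sketch, assuming a hypothetical IAM role ARN and a default session:

import sagemaker
from sagemaker.clarify import SageMakerClarifyProcessor

clarify_processor = SageMakerClarifyProcessor(
    role="arn:aws:iam::111122223333:role/ClarifyRole",  # hypothetical role ARN
    instance_count=1,
    instance_type="ml.c5.xlarge",  # EC2 instance type for the processing job
    sagemaker_session=sagemaker.Session(),
)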
@@ -1163,6 +1251,7 @@ def run_explainability(
 
         Currently, only SHAP and Partial Dependence Plots (PDP) are supported
         as explainability methods.
+        You can request both methods or one at a time with the ``explainability_config`` parameter.
 
         When SHAP is requested in the ``explainability_config``,
         the SHAP algorithm calculates the feature importance for each input example
@@ -1188,6 +1277,8 @@ def run_explainability(
                 Config of the specific explainability method or a list of
                 :class:`~sagemaker.clarify.ExplainabilityConfig` objects.
                 Currently, SHAP and PDP are the two methods supported.
+                You can request multiple methods at once by passing in a list of
+                :class:`~sagemaker.clarify.ExplainabilityConfig`.
             model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
                 Index or JSONPath to locate the predicted scores in the model output. This is not
                 required if the model output is a single score. Alternatively, it can be an instance
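Requesting both supported methods in one job then reduces to passing a list. A sketch reusing the processor, model, and data configs sketched earlier (the PDP feature name is hypothetical; in a real explainability run the predicted label dataset parameters would typically be omitted, since SHAP needs live model inference):

from sagemaker.clarify import PDPConfig

pdp_config = PDPConfig(features=["feature_0"], grid_resolution=15)

clarify_processor.run_explainability(
    data_config=data_config,
    model_config=model_config,
    explainability_config=[shap_config, pdp_config],  # SHAP and PDP in one run
)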