Skip to content

Commit 7073448

Browse files
committed
documentation: Minor updates to Clarify API documentation
Updated the docstring of run_bias method and state clearly that it computes both the pretraining bias metrics and the posttraining bias metrics. Updated the docstring of the constructor to list the supported dataset types, and add a validation at the entrypoint to fail fast if the user provided an unsupported dataset type.
1 parent 5c8ef31 commit 7073448

File tree

2 files changed

+21
-8
lines changed

2 files changed

+21
-8
lines changed

src/sagemaker/clarify.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,17 @@ def __init__(
4848
headers (list[str]): A list of column names in the input dataset.
4949
features (str): JSONPath for locating the feature columns for bias metrics if the
5050
dataset format is JSONLines.
51-
dataset_type (str): Format of the dataset. Valid values are "text/csv" for CSV
52-
and "application/jsonlines" for JSONLines.
51+
dataset_type (str): Format of the dataset. Valid values are "text/csv" for CSV,
52+
"application/jsonlines" for JSONLines, and "application/x-parquet" for Parquet.
5353
s3_data_distribution_type (str): Valid options are "FullyReplicated" or
5454
"ShardedByS3Key".
5555
s3_compression_type (str): Valid options are "None" or "Gzip".
5656
"""
57+
if dataset_type not in ["text/csv", "application/jsonlines", "application/x-parquet"]:
58+
raise ValueError(
59+
f"Invalid dataset_type '{dataset_type}'."
60+
f" Please check the API documentation for the supported dataset types."
61+
)
5762
self.s3_data_input_path = s3_data_input_path
5863
self.s3_output_path = s3_output_path
5964
self.s3_data_distribution_type = s3_data_distribution_type
@@ -508,7 +513,7 @@ def run_pre_training_bias(
508513
kms_key=None,
509514
experiment_config=None,
510515
):
511-
"""Runs a ProcessingJob to compute the requested bias 'methods' of the input data.
516+
"""Runs a ProcessingJob to compute the requested pre-training bias 'methods' of the input data.
512517
513518
Computes the requested methods that compare 'methods' (e.g. fraction of examples) for the
514519
sensitive group vs the other examples.
@@ -562,7 +567,7 @@ def run_post_training_bias(
562567
kms_key=None,
563568
experiment_config=None,
564569
):
565-
"""Runs a ProcessingJob to compute the requested bias 'methods' of the model predictions.
570+
"""Runs a ProcessingJob to compute the requested post-training bias 'methods' of the model predictions.
566571
567572
Spins up a model endpoint, runs inference over the input example in the
568573
's3_data_input_path' to obtain predicted labels. Computes a the requested methods that
@@ -635,10 +640,9 @@ def run_bias(
635640
):
636641
"""Runs a ProcessingJob to compute the requested bias 'methods' of the model predictions.
637642
638-
Spins up a model endpoint, runs inference over the input example in the
639-
's3_data_input_path' to obtain predicted labels. Computes a the requested methods that
640-
compare 'methods' (e.g. accuracy, precision, recall) for the sensitive group vs the other
641-
examples.
643+
The job will compute the metrics of both the pre-training methods and the post-training method.
644+
To calculate post-training methods, it needs to spin up a model endpoint, runs inference over
645+
the input example in the 's3_data_input_path' to obtain predicted labels.
642646
643647
Args:
644648
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.

tests/unit/test_clarify.py

+9
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,15 @@ def test_data_config():
6868
assert "FullyReplicated" == data_config.s3_data_distribution_type
6969

7070

71+
def test_invalid_data_config():
72+
with pytest.raises(ValueError, match=r"^Invalid dataset_type"):
73+
DataConfig(
74+
s3_data_input_path="s3://bucket/inputpath",
75+
s3_output_path="s3://bucket/outputpath",
76+
dataset_type="whatnot_type",
77+
)
78+
79+
7180
def test_data_bias_config():
7281
label_values = [1]
7382
facet_name = "F1"

0 commit comments

Comments
 (0)