Skip to content

Commit 1a2d36f

Browse files
documentation: Minor updates to Clarify API documentation (#2550)
Updated the docstring of run_bias method and state clearly that it computes both the pretraining bias metrics and the posttraining bias metrics. Updated the docstring of the constructor to list the supported dataset types, and add a validation at the entrypoint to fail fast if the user provided an unsupported dataset type. Fix broken links to bias metrics documentation. Co-authored-by: Ahsan Khan <[email protected]>
1 parent e850a15 commit 1a2d36f

File tree

2 files changed

+40
-27
lines changed

2 files changed

+40
-27
lines changed

src/sagemaker/clarify.py

+31-27
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,17 @@ def __init__(
4848
headers (list[str]): A list of column names in the input dataset.
4949
features (str): JSONPath for locating the feature columns for bias metrics if the
5050
dataset format is JSONLines.
51-
dataset_type (str): Format of the dataset. Valid values are "text/csv" for CSV
52-
and "application/jsonlines" for JSONLines.
51+
dataset_type (str): Format of the dataset. Valid values are "text/csv" for CSV,
52+
"application/jsonlines" for JSONLines, and "application/x-parquet" for Parquet.
5353
s3_data_distribution_type (str): Valid options are "FullyReplicated" or
5454
"ShardedByS3Key".
5555
s3_compression_type (str): Valid options are "None" or "Gzip".
5656
"""
57+
if dataset_type not in ["text/csv", "application/jsonlines", "application/x-parquet"]:
58+
raise ValueError(
59+
f"Invalid dataset_type '{dataset_type}'."
60+
f" Please check the API documentation for the supported dataset types."
61+
)
5762
self.s3_data_input_path = s3_data_input_path
5863
self.s3_output_path = s3_output_path
5964
self.s3_data_distribution_type = s3_data_distribution_type
@@ -508,7 +513,7 @@ def run_pre_training_bias(
508513
kms_key=None,
509514
experiment_config=None,
510515
):
511-
"""Runs a ProcessingJob to compute the requested bias 'methods' of the input data.
516+
"""Runs a ProcessingJob to compute the pre-training bias methods of the input data.
512517
513518
Computes the requested methods that compare 'methods' (e.g. fraction of examples) for the
514519
sensitive group vs the other examples.
@@ -517,14 +522,14 @@ def run_pre_training_bias(
517522
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
518523
data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
519524
methods (str or list[str]): Selector of a subset of potential metrics:
520-
["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ci.html>`_",
521-
"`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dpl.html>`_",
522-
"`KL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-kl.html>`_",
523-
"`JS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-js.html>`_",
524-
"`LP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-lp.html>`_",
525-
"`TVD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-tvd.html>`_",
526-
"`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ks.html>`_",
527-
"`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cdd.html>`_"].
525+
["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance.html>`_",
526+
"`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-true-label-imbalance.html>`_",
527+
"`KL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kl-divergence.html>`_",
528+
"`JS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-jensen-shannon-divergence.html>`_",
529+
"`LP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-lp-norm.html>`_",
530+
"`TVD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-total-variation-distance.html>`_",
531+
"`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kolmogorov-smirnov.html>`_",
532+
"`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-cddl.html>`_"].
528533
Defaults to computing all.
529534
wait (bool): Whether the call should wait until the job completes (default: True).
530535
logs (bool): Whether to show the logs produced by the job.
@@ -538,7 +543,7 @@ def run_pre_training_bias(
538543
experiment_config (dict[str, str]): Experiment management configuration.
539544
Dictionary contains three optional keys:
540545
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
541-
"""
546+
""" # noqa E501
542547
analysis_config = data_config.get_config()
543548
analysis_config.update(data_bias_config.get_config())
544549
analysis_config["methods"] = {"pre_training_bias": {"methods": methods}}
@@ -562,7 +567,7 @@ def run_post_training_bias(
562567
kms_key=None,
563568
experiment_config=None,
564569
):
565-
"""Runs a ProcessingJob to compute the requested bias 'methods' of the model predictions.
570+
"""Runs a ProcessingJob to compute the post-training bias methods of the model predictions.
566571
567572
Spins up a model endpoint, runs inference over the input example in the
568573
's3_data_input_path' to obtain predicted labels. Computes a the requested methods that
@@ -633,12 +638,11 @@ def run_bias(
633638
kms_key=None,
634639
experiment_config=None,
635640
):
636-
"""Runs a ProcessingJob to compute the requested bias 'methods' of the model predictions.
641+
"""Runs a ProcessingJob to compute the requested bias methods.
637642
638-
Spins up a model endpoint, runs inference over the input example in the
639-
's3_data_input_path' to obtain predicted labels. Computes a the requested methods that
640-
compare 'methods' (e.g. accuracy, precision, recall) for the sensitive group vs the other
641-
examples.
643+
It computes the metrics of both the pre-training methods and the post-training methods.
644+
To calculate post-training methods, it needs to spin up a model endpoint, runs inference
645+
over the input example in the 's3_data_input_path' to obtain predicted labels.
642646
643647
Args:
644648
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
@@ -648,14 +652,14 @@ def run_bias(
648652
model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
649653
Config of how to extract the predicted label from the model output.
650654
pre_training_methods (str or list[str]): Selector of a subset of potential metrics:
651-
["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ci.html>`_",
652-
"`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dpl.html>`_",
653-
"`KL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-kl.html>`_",
654-
"`JS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-js.html>`_",
655-
"`LP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-lp.html>`_",
656-
"`TVD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-tvd.html>`_",
657-
"`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ks.html>`_",
658-
"`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cdd.html>`_"].
655+
["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance.html>`_",
656+
"`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-true-label-imbalance.html>`_",
657+
"`KL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kl-divergence.html>`_",
658+
"`JS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-jensen-shannon-divergence.html>`_",
659+
"`LP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-lp-norm.html>`_",
660+
"`TVD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-total-variation-distance.html>`_",
661+
"`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kolmogorov-smirnov.html>`_",
662+
"`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-cddl.html>`_"].
659663
Defaults to computing all.
660664
post_training_methods (str or list[str]): Selector of a subset of potential metrics:
661665
["`DPPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dppl.html>`_"
@@ -682,7 +686,7 @@ def run_bias(
682686
experiment_config (dict[str, str]): Experiment management configuration.
683687
Dictionary contains three optional keys:
684688
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
685-
"""
689+
""" # noqa E501
686690
analysis_config = data_config.get_config()
687691
analysis_config.update(bias_config.get_config())
688692
analysis_config["predictor"] = model_config.get_predictor_config()

tests/unit/test_clarify.py

+9
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,15 @@ def test_data_config():
6868
assert "FullyReplicated" == data_config.s3_data_distribution_type
6969

7070

71+
def test_invalid_data_config():
72+
with pytest.raises(ValueError, match=r"^Invalid dataset_type"):
73+
DataConfig(
74+
s3_data_input_path="s3://bucket/inputpath",
75+
s3_output_path="s3://bucket/outputpath",
76+
dataset_type="whatnot_type",
77+
)
78+
79+
7180
def test_data_bias_config():
7281
label_values = [1]
7382
facet_name = "F1"

0 commit comments

Comments
 (0)