@@ -48,12 +48,17 @@ def __init__(
48
48
headers (list[str]): A list of column names in the input dataset.
49
49
features (str): JSONPath for locating the feature columns for bias metrics if the
50
50
dataset format is JSONLines.
51
- dataset_type (str): Format of the dataset. Valid values are "text/csv" for CSV
52
- and "application/jsonlines" for JSONLines.
51
+ dataset_type (str): Format of the dataset. Valid values are "text/csv" for CSV,
52
+ "application/jsonlines" for JSONLines, and "application/x-parquet" for Parquet .
53
53
s3_data_distribution_type (str): Valid options are "FullyReplicated" or
54
54
"ShardedByS3Key".
55
55
s3_compression_type (str): Valid options are "None" or "Gzip".
56
56
"""
57
+ if dataset_type not in ["text/csv" , "application/jsonlines" , "application/x-parquet" ]:
58
+ raise ValueError (
59
+ f"Invalid dataset_type '{ dataset_type } '."
60
+ f" Please check the API documentation for the supported dataset types."
61
+ )
57
62
self .s3_data_input_path = s3_data_input_path
58
63
self .s3_output_path = s3_output_path
59
64
self .s3_data_distribution_type = s3_data_distribution_type
@@ -508,7 +513,7 @@ def run_pre_training_bias(
508
513
kms_key = None ,
509
514
experiment_config = None ,
510
515
):
511
- """Runs a ProcessingJob to compute the requested bias ' methods' of the input data.
516
+ """Runs a ProcessingJob to compute the pre-training bias methods of the input data.
512
517
513
518
Computes the requested methods that compare 'methods' (e.g. fraction of examples) for the
514
519
sensitive group vs the other examples.
@@ -517,14 +522,14 @@ def run_pre_training_bias(
517
522
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
518
523
data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
519
524
methods (str or list[str]): Selector of a subset of potential metrics:
520
- ["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training- bias-metric-ci .html>`_",
521
- "`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training- bias-metric-dpl .html>`_",
522
- "`KL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training- bias-metric-kl.html>`_",
523
- "`JS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training- bias-metric-js .html>`_",
524
- "`LP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training- bias-metric-lp.html>`_",
525
- "`TVD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training- bias-metric-tvd .html>`_",
526
- "`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training- bias-metric-ks .html>`_",
527
- "`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training- bias-metric-cdd .html>`_"].
525
+ ["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance .html>`_",
526
+ "`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data- bias-metric-true-label-imbalance .html>`_",
527
+ "`KL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data- bias-metric-kl-divergence .html>`_",
528
+ "`JS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data- bias-metric-jensen-shannon-divergence .html>`_",
529
+ "`LP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data- bias-metric-lp-norm .html>`_",
530
+ "`TVD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data- bias-metric-total-variation-distance .html>`_",
531
+ "`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data- bias-metric-kolmogorov-smirnov .html>`_",
532
+ "`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data- bias-metric-cddl .html>`_"].
528
533
Defaults to computing all.
529
534
wait (bool): Whether the call should wait until the job completes (default: True).
530
535
logs (bool): Whether to show the logs produced by the job.
@@ -538,7 +543,7 @@ def run_pre_training_bias(
538
543
experiment_config (dict[str, str]): Experiment management configuration.
539
544
Dictionary contains three optional keys:
540
545
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
541
- """
546
+ """ # noqa E501
542
547
analysis_config = data_config .get_config ()
543
548
analysis_config .update (data_bias_config .get_config ())
544
549
analysis_config ["methods" ] = {"pre_training_bias" : {"methods" : methods }}
@@ -562,7 +567,7 @@ def run_post_training_bias(
562
567
kms_key = None ,
563
568
experiment_config = None ,
564
569
):
565
- """Runs a ProcessingJob to compute the requested bias ' methods' of the model predictions.
570
+ """Runs a ProcessingJob to compute the post-training bias methods of the model predictions.
566
571
567
572
Spins up a model endpoint, runs inference over the input example in the
568
573
's3_data_input_path' to obtain predicted labels. Computes a the requested methods that
@@ -633,12 +638,11 @@ def run_bias(
633
638
kms_key = None ,
634
639
experiment_config = None ,
635
640
):
636
- """Runs a ProcessingJob to compute the requested bias ' methods' of the model predictions .
641
+ """Runs a ProcessingJob to compute the requested bias methods.
637
642
638
- Spins up a model endpoint, runs inference over the input example in the
639
- 's3_data_input_path' to obtain predicted labels. Computes a the requested methods that
640
- compare 'methods' (e.g. accuracy, precision, recall) for the sensitive group vs the other
641
- examples.
643
+ It computes the metrics of both the pre-training methods and the post-training methods.
644
+ To calculate post-training methods, it needs to spin up a model endpoint, runs inference
645
+ over the input example in the 's3_data_input_path' to obtain predicted labels.
642
646
643
647
Args:
644
648
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
@@ -648,14 +652,14 @@ def run_bias(
648
652
model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
649
653
Config of how to extract the predicted label from the model output.
650
654
pre_training_methods (str or list[str]): Selector of a subset of potential metrics:
651
- ["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training- bias-metric-ci .html>`_",
652
- "`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training- bias-metric-dpl .html>`_",
653
- "`KL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training- bias-metric-kl.html>`_",
654
- "`JS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training- bias-metric-js .html>`_",
655
- "`LP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training- bias-metric-lp.html>`_",
656
- "`TVD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training- bias-metric-tvd .html>`_",
657
- "`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training- bias-metric-ks .html>`_",
658
- "`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training- bias-metric-cdd .html>`_"].
655
+ ["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance .html>`_",
656
+ "`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data- bias-metric-true-label-imbalance .html>`_",
657
+ "`KL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data- bias-metric-kl-divergence .html>`_",
658
+ "`JS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data- bias-metric-jensen-shannon-divergence .html>`_",
659
+ "`LP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data- bias-metric-lp-norm .html>`_",
660
+ "`TVD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data- bias-metric-total-variation-distance .html>`_",
661
+ "`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data- bias-metric-kolmogorov-smirnov .html>`_",
662
+ "`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data- bias-metric-cddl .html>`_"].
659
663
Defaults to computing all.
660
664
post_training_methods (str or list[str]): Selector of a subset of potential metrics:
661
665
["`DPPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dppl.html>`_"
@@ -682,7 +686,7 @@ def run_bias(
682
686
experiment_config (dict[str, str]): Experiment management configuration.
683
687
Dictionary contains three optional keys:
684
688
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
685
- """
689
+ """ # noqa E501
686
690
analysis_config = data_config .get_config ()
687
691
analysis_config .update (bias_config .get_config ())
688
692
analysis_config ["predictor" ] = model_config .get_predictor_config ()
0 commit comments