@@ -48,12 +48,17 @@ def __init__(
48
48
headers (list[str]): A list of column names in the input dataset.
49
49
features (str): JSONPath for locating the feature columns for bias metrics if the
50
50
dataset format is JSONLines.
51
- dataset_type (str): Format of the dataset. Valid values are "text/csv" for CSV
52
- and "application/jsonlines" for JSONLines.
51
+ dataset_type (str): Format of the dataset. Valid values are "text/csv" for CSV,
52
+ "application/jsonlines" for JSONLines, and "application/x-parquet" for Parquet .
53
53
s3_data_distribution_type (str): Valid options are "FullyReplicated" or
54
54
"ShardedByS3Key".
55
55
s3_compression_type (str): Valid options are "None" or "Gzip".
56
56
"""
57
+ if dataset_type not in ["text/csv" , "application/jsonlines" , "application/x-parquet" ]:
58
+ raise ValueError (
59
+ f"Invalid dataset_type '{ dataset_type } '."
60
+ f" Please check the API documentation for the supported dataset types."
61
+ )
57
62
self .s3_data_input_path = s3_data_input_path
58
63
self .s3_output_path = s3_output_path
59
64
self .s3_data_distribution_type = s3_data_distribution_type
@@ -508,7 +513,7 @@ def run_pre_training_bias(
508
513
kms_key = None ,
509
514
experiment_config = None ,
510
515
):
511
- """Runs a ProcessingJob to compute the requested bias 'methods' of the input data.
516
+ """Runs a ProcessingJob to compute the requested pre-training bias 'methods' of the input data.
512
517
513
518
Computes the requested methods that compare 'methods' (e.g. fraction of examples) for the
514
519
sensitive group vs the other examples.
@@ -562,7 +567,7 @@ def run_post_training_bias(
562
567
kms_key = None ,
563
568
experiment_config = None ,
564
569
):
565
- """Runs a ProcessingJob to compute the requested bias 'methods' of the model predictions.
570
+ """Runs a ProcessingJob to compute the requested post-training bias 'methods' of the model predictions.
566
571
567
572
Spins up a model endpoint, runs inference over the input example in the
568
573
's3_data_input_path' to obtain predicted labels. Computes a the requested methods that
@@ -635,10 +640,9 @@ def run_bias(
635
640
):
636
641
"""Runs a ProcessingJob to compute the requested bias 'methods' of the model predictions.
637
642
638
- Spins up a model endpoint, runs inference over the input example in the
639
- 's3_data_input_path' to obtain predicted labels. Computes a the requested methods that
640
- compare 'methods' (e.g. accuracy, precision, recall) for the sensitive group vs the other
641
- examples.
643
+ The job will compute the metrics of both the pre-training methods and the post-training method.
644
+ To calculate post-training methods, it needs to spin up a model endpoint, runs inference over
645
+ the input example in the 's3_data_input_path' to obtain predicted labels.
642
646
643
647
Args:
644
648
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
0 commit comments