diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py
index f8b83d5d28..d1bda766ee 100644
--- a/src/sagemaker/clarify.py
+++ b/src/sagemaker/clarify.py
@@ -48,12 +48,17 @@ def __init__(
             headers (list[str]): A list of column names in the input dataset.
             features (str): JSONPath for locating the feature columns for bias metrics if the
                 dataset format is JSONLines.
-            dataset_type (str): Format of the dataset. Valid values are "text/csv" for CSV
-                and "application/jsonlines" for JSONLines.
+            dataset_type (str): Format of the dataset. Valid values are "text/csv" for CSV,
+                "application/jsonlines" for JSONLines, and "application/x-parquet" for Parquet.
             s3_data_distribution_type (str): Valid options are "FullyReplicated" or
                 "ShardedByS3Key".
             s3_compression_type (str): Valid options are "None" or "Gzip".
         """
+        if dataset_type not in ["text/csv", "application/jsonlines", "application/x-parquet"]:
+            raise ValueError(
+                f"Invalid dataset_type '{dataset_type}'."
+                f" Please check the API documentation for the supported dataset types."
+            )
         self.s3_data_input_path = s3_data_input_path
         self.s3_output_path = s3_output_path
         self.s3_data_distribution_type = s3_data_distribution_type
@@ -508,7 +513,7 @@ def run_pre_training_bias(
         kms_key=None,
         experiment_config=None,
     ):
-        """Runs a ProcessingJob to compute the requested bias 'methods' of the input data.
+        """Runs a ProcessingJob to compute the pre-training bias methods of the input data.
 
         Computes the requested methods that compare 'methods' (e.g. fraction of examples) for the
         sensitive group vs the other examples.
@@ -517,14 +522,14 @@ def run_pre_training_bias(
             data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
             data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
             methods (str or list[str]): Selector of a subset of potential metrics:
-                ["`CI `_",
-                "`DPL `_",
-                "`KL `_",
-                "`JS `_",
-                "`LP `_",
-                "`TVD `_",
-                "`KS `_",
-                "`CDDL `_"].
+                ["`CI `_",
+                "`DPL `_",
+                "`KL `_",
+                "`JS `_",
+                "`LP `_",
+                "`TVD `_",
+                "`KS `_",
+                "`CDDL `_"].
                 Defaults to computing all.
             wait (bool): Whether the call should wait until the job completes (default: True).
             logs (bool): Whether to show the logs produced by the job.
@@ -538,7 +543,7 @@ def run_pre_training_bias(
             experiment_config (dict[str, str]): Experiment management configuration.
                 Dictionary contains three optional keys:
                 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
-        """
+        """  # noqa E501
         analysis_config = data_config.get_config()
         analysis_config.update(data_bias_config.get_config())
         analysis_config["methods"] = {"pre_training_bias": {"methods": methods}}
@@ -562,7 +567,7 @@ def run_post_training_bias(
         kms_key=None,
         experiment_config=None,
     ):
-        """Runs a ProcessingJob to compute the requested bias 'methods' of the model predictions.
+        """Runs a ProcessingJob to compute the post-training bias methods of the model predictions.
 
         Spins up a model endpoint, runs inference over the input example in the
         's3_data_input_path' to obtain predicted labels. Computes a the requested methods that
@@ -633,12 +638,11 @@ def run_bias(
         kms_key=None,
         experiment_config=None,
     ):
-        """Runs a ProcessingJob to compute the requested bias 'methods' of the model predictions.
+        """Runs a ProcessingJob to compute the requested bias methods.
 
-        Spins up a model endpoint, runs inference over the input example in the
-        's3_data_input_path' to obtain predicted labels. Computes a the requested methods that
-        compare 'methods' (e.g. accuracy, precision, recall) for the sensitive group vs the other
-        examples.
+        Computes the metrics of both the pre-training and the post-training methods.
+        To calculate the post-training methods, it spins up a model endpoint and runs inference
+        over the input example in the 's3_data_input_path' to obtain predicted labels.
 
         Args:
             data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
@@ -648,14 +652,14 @@ def run_bias(
             model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
                 Config of how to extract the predicted label from the model output.
             pre_training_methods (str or list[str]): Selector of a subset of potential metrics:
-                ["`CI `_",
-                "`DPL `_",
-                "`KL `_",
-                "`JS `_",
-                "`LP `_",
-                "`TVD `_",
-                "`KS `_",
-                "`CDDL `_"].
+                ["`CI `_",
+                "`DPL `_",
+                "`KL `_",
+                "`JS `_",
+                "`LP `_",
+                "`TVD `_",
+                "`KS `_",
+                "`CDDL `_"].
                 Defaults to computing all.
             post_training_methods (str or list[str]): Selector of a subset of potential metrics:
                 ["`DPPL `_"
@@ -682,7 +686,7 @@ def run_bias(
             experiment_config (dict[str, str]): Experiment management configuration.
                 Dictionary contains three optional keys:
                 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
-        """
+        """  # noqa E501
         analysis_config = data_config.get_config()
         analysis_config.update(bias_config.get_config())
         analysis_config["predictor"] = model_config.get_predictor_config()
diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py
index f9cbc14c1e..32fb9d8480 100644
--- a/tests/unit/test_clarify.py
+++ b/tests/unit/test_clarify.py
@@ -68,6 +68,15 @@ def test_data_config():
     assert "FullyReplicated" == data_config.s3_data_distribution_type
 
 
+def test_invalid_data_config():
+    with pytest.raises(ValueError, match=r"^Invalid dataset_type"):
+        DataConfig(
+            s3_data_input_path="s3://bucket/inputpath",
+            s3_output_path="s3://bucket/outputpath",
+            dataset_type="whatnot_type",
+        )
+
+
 def test_data_bias_config():
     label_values = [1]
     facet_name = "F1"
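
A minimal usage sketch (not part of the diff) of the stricter dataset_type validation introduced above; the bucket names and paths are placeholders, not values from the change:

    from sagemaker.clarify import DataConfig

    # Parquet is now accepted alongside CSV and JSONLines.
    parquet_config = DataConfig(
        s3_data_input_path="s3://my-bucket/input/",       # placeholder path
        s3_output_path="s3://my-bucket/clarify-output/",  # placeholder path
        dataset_type="application/x-parquet",
    )

    # Any unsupported value now fails fast at construction time
    # instead of surfacing later in the processing job.
    try:
        DataConfig(
            s3_data_input_path="s3://my-bucket/input/",
            s3_output_path="s3://my-bucket/clarify-output/",
            dataset_type="text/tsv",
        )
    except ValueError as err:
        print(err)  # Invalid dataset_type 'text/tsv'. Please check the API documentation ...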