diff --git a/setup.py b/setup.py
index aff6bb938b..c022e12d6f 100644
--- a/setup.py
+++ b/setup.py
@@ -58,6 +58,7 @@ def read_requirements(filename):
     "packaging>=20.0",
     "pandas",
     "pathos",
+    "schema",
 ]
 
 # Specific use case dependencies
diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py
index 781bae30fb..4765630ce8 100644
--- a/src/sagemaker/clarify.py
+++ b/src/sagemaker/clarify.py
@@ -27,6 +27,8 @@
 from abc import ABC, abstractmethod
 from typing import List, Union, Dict, Optional, Any
 
+from schema import Schema, And, Use, Or, Optional as SchemaOptional, Regex
+
 from sagemaker import image_uris, s3, utils
 from sagemaker.session import Session
 from sagemaker.network import NetworkConfig
@@ -35,6 +37,271 @@
 
 logger = logging.getLogger(__name__)
 
+ENDPOINT_NAME_PREFIX_PATTERN = "^[a-zA-Z0-9](-*[a-zA-Z0-9])*"
+
+ANALYSIS_CONFIG_SCHEMA_V1_0 = Schema(
+    {
+        SchemaOptional("version"): str,
+        "dataset_type": And(
+            str,
+            Use(str.lower),
+            lambda s: s
+            in (
+                "text/csv",
+                "application/jsonlines",
+                "application/sagemakercapturejson",
+                "application/x-parquet",
+                "application/x-image",
+            ),
+        ),
+        SchemaOptional("dataset_uri"): str,
+        SchemaOptional("headers"): [str],
+        SchemaOptional("label"): Or(str, int),
+        # this field indicates that the user provides the predicted_label in the dataset
+        SchemaOptional("predicted_label"): Or(str, int),
+        SchemaOptional("features"): str,
+        SchemaOptional("label_values_or_threshold"): [Or(int, float, str)],
+        SchemaOptional("probability_threshold"): float,
+        SchemaOptional("facet"): [
+            {
+                "name_or_index": Or(str, int),
+                SchemaOptional("value_or_threshold"): [Or(int, float, str)],
+            }
+        ],
+        SchemaOptional("facet_dataset_uri"): str,
+        SchemaOptional("facet_headers"): [str],
+        SchemaOptional("predicted_label_dataset_uri"): str,
+        SchemaOptional("predicted_label_headers"): [str],
+        SchemaOptional("excluded_columns"): [Or(int, str)],
+        SchemaOptional("joinsource_name_or_index"): Or(str, int),
+        SchemaOptional("group_variable"): Or(str, int),
+        "methods": {
+            SchemaOptional("shap"): {
+                SchemaOptional("baseline"): Or(
+                    # URI of the baseline data file
+                    str,
+                    # In-place baseline data (a list of rows)
+                    [
+                        Or(
+                            # CSV row
+                            [Or(int, float, str, None)],
+                            # JSON row (any JSON object). At the time of writing, only
+                            # the SageMaker JSONLines dense format ([1]) is supported,
+                            # and that validation is done by the data loader, NOT by
+                            # this schema.
+                            # [1] https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-inference.html#cm-jsonlines
+                            {object: object},
+                        )
+                    ],
+                ),
+                SchemaOptional("num_clusters"): int,
+                SchemaOptional("use_logit"): bool,
+                SchemaOptional("num_samples"): int,
+                SchemaOptional("agg_method"): And(
+                    str, Use(str.lower), lambda s: s in ("mean_abs", "median", "mean_sq")
+                ),
+                SchemaOptional("save_local_shap_values"): bool,
+                SchemaOptional("text_config"): {
+                    "granularity": And(
+                        str, Use(str.lower), lambda s: s in ("token", "sentence", "paragraph")
+                    ),
+                    "language": And(
+                        str,
+                        Use(str.lower),
+                        lambda s: s
+                        in (
+                            "chinese",
+                            "zh",
+                            "danish",
+                            "da",
+                            "dutch",
+                            "nl",
+                            "english",
+                            "en",
+                            "french",
+                            "fr",
+                            "german",
+                            "de",
+                            "greek",
+                            "el",
+                            "italian",
+                            "it",
+                            "japanese",
+                            "ja",
+                            "lithuanian",
+                            "lt",
+                            "multi-language",
+                            "xx",
+                            "norwegian bokmål",
+                            "nb",
+                            "polish",
+                            "pl",
+                            "portuguese",
+                            "pt",
+                            "romanian",
+                            "ro",
+                            "russian",
+                            "ru",
+                            "spanish",
+                            "es",
+                            "afrikaans",
+                            "af",
+                            "albanian",
+                            "sq",
+                            "arabic",
+                            "ar",
+                            "armenian",
+                            "hy",
+                            "basque",
+                            "eu",
+                            "bengali",
+                            "bn",
+                            "bulgarian",
+                            "bg",
+                            "catalan",
+                            "ca",
+                            "croatian",
+                            "hr",
+                            "czech",
+                            "cs",
+                            "estonian",
+                            "et",
+                            "finnish",
+                            "fi",
+                            "gujarati",
+                            "gu",
+                            "hebrew",
+                            "he",
+                            "hindi",
+                            "hi",
+                            "hungarian",
+                            "hu",
+                            "icelandic",
+                            "is",
+                            "indonesian",
+                            "id",
+                            "irish",
+                            "ga",
+                            "kannada",
+                            "kn",
+                            "kyrgyz",
+                            "ky",
+                            "latvian",
+                            "lv",
+                            "ligurian",
+                            "lij",
+                            "luxembourgish",
+                            "lb",
+                            "macedonian",
+                            "mk",
+                            "malayalam",
+                            "ml",
+                            "marathi",
+                            "mr",
+                            "nepali",
+                            "ne",
+                            "persian",
+                            "fa",
+                            "sanskrit",
+                            "sa",
+                            "serbian",
+                            "sr",
+                            "setswana",
+                            "tn",
+                            "sinhala",
+                            "si",
+                            "slovak",
+                            "sk",
+                            "slovenian",
+                            "sl",
+                            "swedish",
+                            "sv",
+                            "tagalog",
+                            "tl",
+                            "tamil",
+                            "ta",
+                            "tatar",
+                            "tt",
+                            "telugu",
+                            "te",
+                            "thai",
+                            "th",
+                            "turkish",
+                            "tr",
+                            "ukrainian",
+                            "uk",
+                            "urdu",
+                            "ur",
+                            "vietnamese",
+                            "vi",
+                            "yoruba",
+                            "yo",
+                        ),
+                    ),
+                    SchemaOptional("max_top_tokens"): int,
+                },
+                SchemaOptional("image_config"): {
+                    SchemaOptional("num_segments"): int,
+                    SchemaOptional("segment_compactness"): int,
+                    SchemaOptional("feature_extraction_method"): str,
+                    SchemaOptional("model_type"): str,
+                    SchemaOptional("max_objects"): int,
+                    SchemaOptional("iou_threshold"): float,
+                    SchemaOptional("context"): float,
+                    SchemaOptional("debug"): {
+                        SchemaOptional("image_names"): [str],
+                        SchemaOptional("class_ids"): [int],
+                        SchemaOptional("sample_from"): int,
+                        SchemaOptional("sample_to"): int,
+                    },
+                },
+                SchemaOptional("seed"): int,
+            },
+            SchemaOptional("pre_training_bias"): {"methods": Or(str, [str])},
+            SchemaOptional("post_training_bias"): {"methods": Or(str, [str])},
+            SchemaOptional("pdp"): {
+                "grid_resolution": int,
+                SchemaOptional("features"): [Or(str, int)],
+                SchemaOptional("top_k_features"): int,
+            },
+            SchemaOptional("report"): {"name": str, SchemaOptional("title"): str},
+        },
+        SchemaOptional("predictor"): {
+            SchemaOptional("endpoint_name"): str,
+            SchemaOptional("endpoint_name_prefix"): And(str, Regex(ENDPOINT_NAME_PREFIX_PATTERN)),
+            SchemaOptional("model_name"): str,
+            SchemaOptional("target_model"): str,
+            SchemaOptional("instance_type"): str,
+            SchemaOptional("initial_instance_count"): int,
+            SchemaOptional("accelerator_type"): str,
+            SchemaOptional("content_type"): And(
+                str,
+                Use(str.lower),
+                lambda s: s
+                in (
+                    "text/csv",
+                    "application/jsonlines",
+                    "image/jpeg",
+                    "image/jpg",
+                    "image/png",
+                    "application/x-npy",
+                ),
+            ),
+            SchemaOptional("accept_type"): And(
+                str,
+                Use(str.lower),
+                lambda s: s in ("text/csv", "application/jsonlines", "application/json"),
+            ),
+            SchemaOptional("label"): Or(str, int),
+            SchemaOptional("probability"): Or(str, int),
+            SchemaOptional("label_headers"): [Or(str, int)],
+            SchemaOptional("content_template"): Or(str, {str: str}),
+            SchemaOptional("custom_attributes"): str,
+        },
+    }
+)
+
+
 class DataConfig:
     """Config object related to configurations of the input and output dataset."""
 
@@ -926,6 +1193,7 @@ def __init__(
         network_config: Optional[NetworkConfig] = None,
         job_name_prefix: Optional[str] = None,
         version: Optional[str] = None,
+        skip_early_validation: bool = False,
     ):
         """Initializes a SageMakerClarifyProcessor to compute bias metrics and model explanations.
 
@@ -967,10 +1235,12 @@ def __init__(
                 inter-container traffic, security group IDs, and subnets.
             job_name_prefix (str): Processing job name prefix.
             version (str): Clarify version to use.
+            skip_early_validation (bool): Whether to skip schema validation of the generated analysis config (``analysis_config.json``).
         """  # noqa E501  # pylint: disable=c0301
         container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version)
         self._last_analysis_config = None
         self.job_name_prefix = job_name_prefix
+        self.skip_early_validation = skip_early_validation
         super(SageMakerClarifyProcessor, self).__init__(
             role,
             container_uri,
@@ -1034,6 +1304,8 @@ def _run(
         # for debugging: to access locally, i.e. without a need to look for it in an S3 bucket
         self._last_analysis_config = analysis_config
         logger.info("Analysis Config: %s", analysis_config)
+        if not self.skip_early_validation:
+            ANALYSIS_CONFIG_SCHEMA_V1_0.validate(analysis_config)
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")
diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py
index c9400a7be4..de482997ef 100644
--- a/tests/unit/test_clarify.py
+++ b/tests/unit/test_clarify.py
@@ -914,6 +914,7 @@ def test_run_on_s3_analysis_config_file(
     processor_run, sagemaker_session, clarify_processor, data_config
 ):
     analysis_config = {
+        "dataset_type": "text/csv",
         "methods": {"post_training_bias": {"methods": "all"}},
     }
     with patch("sagemaker.clarify._upload_analysis_config", return_value=None) as mock_method:
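
As a quick sanity check for reviewers, here is a minimal sketch of what the new early validation does. It assumes the `schema` package is installed and this branch is importable; the config dict below is an illustrative example, not a canonical one. Only `ANALYSIS_CONFIG_SCHEMA_V1_0` and the `SchemaError` type come from this patch and the `schema` library.

```python
# Minimal sketch of the early validation added above (assumptions noted in text).
from schema import SchemaError

from sagemaker.clarify import ANALYSIS_CONFIG_SCHEMA_V1_0

# An illustrative, hand-written analysis config.
analysis_config = {
    "dataset_type": "text/csv",
    "headers": ["age", "income", "label"],
    "label": "label",
    "methods": {"post_training_bias": {"methods": "all"}},
}

# A valid config passes; Schema.validate returns the (possibly normalized) data.
ANALYSIS_CONFIG_SCHEMA_V1_0.validate(analysis_config)

# An invalid value now fails client-side, before any processing job is created.
try:
    ANALYSIS_CONFIG_SCHEMA_V1_0.validate({"dataset_type": "text/plain", "methods": {}})
except SchemaError as error:
    print(f"Rejected early: {error}")
```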
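
Opting out restores the previous behavior, where the Clarify container is the only validator. A hypothetical instantiation follows, assuming configured AWS credentials; the role ARN and instance settings are placeholders:

```python
# Hypothetical usage of the new opt-out flag; the role ARN and instance
# settings are placeholders, and Session() assumes configured AWS credentials.
from sagemaker.clarify import SageMakerClarifyProcessor
from sagemaker.session import Session

processor = SageMakerClarifyProcessor(
    role="arn:aws:iam::111122223333:role/ClarifyRole",  # placeholder ARN
    instance_count=1,
    instance_type="ml.c5.xlarge",
    sagemaker_session=Session(),
    skip_early_validation=True,  # defer all validation to the Clarify container
)
```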