|
27 | 27 | from abc import ABC, abstractmethod
|
28 | 28 | from typing import List, Union, Dict, Optional, Any
|
29 | 29 |
|
| 30 | +from schema import Schema, And, Use, Or, Optional as SchemaOptional, Regex |
| 31 | + |
30 | 32 | from sagemaker import image_uris, s3, utils
|
31 | 33 | from sagemaker.session import Session
|
32 | 34 | from sagemaker.network import NetworkConfig
|
|
35 | 37 | logger = logging.getLogger(__name__)
|
36 | 38 |
|
37 | 39 |
|
| 40 | +ENDPOINT_NAME_PREFIX_PATTERN = "^[a-zA-Z0-9](-*[a-zA-Z0-9])" |
| 41 | + |
| 42 | +ANALYSIS_CONFIG_SCHEMA_V1_0 = Schema( |
| 43 | + { |
| 44 | + SchemaOptional("version"): str, |
| 45 | + "dataset_type": And( |
| 46 | + str, |
| 47 | + Use(str.lower), |
| 48 | + lambda s: s |
| 49 | + in ( |
| 50 | + "text/csv", |
| 51 | + "application/jsonlines", |
| 52 | + "application/sagemakercapturejson", |
| 53 | + "application/x-parquet", |
| 54 | + "application/x-image", |
| 55 | + ), |
| 56 | + ), |
| 57 | + SchemaOptional("dataset_uri"): str, |
| 58 | + SchemaOptional("headers"): [str], |
| 59 | + SchemaOptional("label"): Or(str, int), |
| 60 | + # this field indicates user provides predicted_label in dataset |
| 61 | + SchemaOptional("predicted_label"): Or(str, int), |
| 62 | + SchemaOptional("features"): str, |
| 63 | + SchemaOptional("label_values_or_threshold"): [Or(int, float, str)], |
| 64 | + SchemaOptional("probability_threshold"): float, |
| 65 | + SchemaOptional("facet"): [ |
| 66 | + { |
| 67 | + "name_or_index": Or(str, int), |
| 68 | + SchemaOptional("value_or_threshold"): [Or(int, float, str)], |
| 69 | + } |
| 70 | + ], |
| 71 | + SchemaOptional("facet_dataset_uri"): str, |
| 72 | + SchemaOptional("facet_headers"): [str], |
| 73 | + SchemaOptional("predicted_label_dataset_uri"): str, |
| 74 | + SchemaOptional("predicted_label_headers"): [str], |
| 75 | + SchemaOptional("excluded_columns"): [Or(int, str)], |
| 76 | + SchemaOptional("joinsource_name_or_index"): Or(str, int), |
| 77 | + SchemaOptional("group_variable"): Or(str, int), |
| 78 | + "methods": { |
| 79 | + SchemaOptional("shap"): { |
| 80 | + SchemaOptional("baseline"): Or( |
| 81 | + # URI of the baseline data file |
| 82 | + str, |
| 83 | + # Inplace baseline data (a list of something) |
| 84 | + [ |
| 85 | + Or( |
| 86 | + # CSV row |
| 87 | + [Or(int, float, str, None)], |
| 88 | + # JSON row (any JSON object). As I write this only |
| 89 | + # SageMaker JSONLines Dense Format ([1]) |
| 90 | + # is supported and the validation is NOT done |
| 91 | + # by the schema but by the data loader. |
| 92 | + # [1] https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-inference.html#cm-jsonlines |
| 93 | + {object: object}, |
| 94 | + ) |
| 95 | + ], |
| 96 | + ), |
| 97 | + SchemaOptional("num_clusters"): int, |
| 98 | + SchemaOptional("use_logit"): bool, |
| 99 | + SchemaOptional("num_samples"): int, |
| 100 | + SchemaOptional("agg_method"): And( |
| 101 | + str, Use(str.lower), lambda s: s in ("mean_abs", "median", "mean_sq") |
| 102 | + ), |
| 103 | + SchemaOptional("save_local_shap_values"): bool, |
| 104 | + SchemaOptional("text_config"): { |
| 105 | + "granularity": And( |
| 106 | + str, Use(str.lower), lambda s: s in ("token", "sentence", "paragraph") |
| 107 | + ), |
| 108 | + "language": And( |
| 109 | + str, |
| 110 | + Use(str.lower), |
| 111 | + lambda s: s |
| 112 | + in ( |
| 113 | + "chinese", |
| 114 | + "zh", |
| 115 | + "danish", |
| 116 | + "da", |
| 117 | + "dutch", |
| 118 | + "nl", |
| 119 | + "english", |
| 120 | + "en", |
| 121 | + "french", |
| 122 | + "fr", |
| 123 | + "german", |
| 124 | + "de", |
| 125 | + "greek", |
| 126 | + "el", |
| 127 | + "italian", |
| 128 | + "it", |
| 129 | + "japanese", |
| 130 | + "ja", |
| 131 | + "lithuanian", |
| 132 | + "lt", |
| 133 | + "multi-language", |
| 134 | + "xx", |
| 135 | + "norwegian bokmål", |
| 136 | + "nb", |
| 137 | + "polish", |
| 138 | + "pl", |
| 139 | + "portuguese", |
| 140 | + "pt", |
| 141 | + "romanian", |
| 142 | + "ro", |
| 143 | + "russian", |
| 144 | + "ru", |
| 145 | + "spanish", |
| 146 | + "es", |
| 147 | + "afrikaans", |
| 148 | + "af", |
| 149 | + "albanian", |
| 150 | + "sq", |
| 151 | + "arabic", |
| 152 | + "ar", |
| 153 | + "armenian", |
| 154 | + "hy", |
| 155 | + "basque", |
| 156 | + "eu", |
| 157 | + "bengali", |
| 158 | + "bn", |
| 159 | + "bulgarian", |
| 160 | + "bg", |
| 161 | + "catalan", |
| 162 | + "ca", |
| 163 | + "croatian", |
| 164 | + "hr", |
| 165 | + "czech", |
| 166 | + "cs", |
| 167 | + "estonian", |
| 168 | + "et", |
| 169 | + "finnish", |
| 170 | + "fi", |
| 171 | + "gujarati", |
| 172 | + "gu", |
| 173 | + "hebrew", |
| 174 | + "he", |
| 175 | + "hindi", |
| 176 | + "hi", |
| 177 | + "hungarian", |
| 178 | + "hu", |
| 179 | + "icelandic", |
| 180 | + "is", |
| 181 | + "indonesian", |
| 182 | + "id", |
| 183 | + "irish", |
| 184 | + "ga", |
| 185 | + "kannada", |
| 186 | + "kn", |
| 187 | + "kyrgyz", |
| 188 | + "ky", |
| 189 | + "latvian", |
| 190 | + "lv", |
| 191 | + "ligurian", |
| 192 | + "lij", |
| 193 | + "luxembourgish", |
| 194 | + "lb", |
| 195 | + "macedonian", |
| 196 | + "mk", |
| 197 | + "malayalam", |
| 198 | + "ml", |
| 199 | + "marathi", |
| 200 | + "mr", |
| 201 | + "nepali", |
| 202 | + "ne", |
| 203 | + "persian", |
| 204 | + "fa", |
| 205 | + "sanskrit", |
| 206 | + "sa", |
| 207 | + "serbian", |
| 208 | + "sr", |
| 209 | + "setswana", |
| 210 | + "tn", |
| 211 | + "sinhala", |
| 212 | + "si", |
| 213 | + "slovak", |
| 214 | + "sk", |
| 215 | + "slovenian", |
| 216 | + "sl", |
| 217 | + "swedish", |
| 218 | + "sv", |
| 219 | + "tagalog", |
| 220 | + "tl", |
| 221 | + "tamil", |
| 222 | + "ta", |
| 223 | + "tatar", |
| 224 | + "tt", |
| 225 | + "telugu", |
| 226 | + "te", |
| 227 | + "thai", |
| 228 | + "th", |
| 229 | + "turkish", |
| 230 | + "tr", |
| 231 | + "ukrainian", |
| 232 | + "uk", |
| 233 | + "urdu", |
| 234 | + "ur", |
| 235 | + "vietnamese", |
| 236 | + "vi", |
| 237 | + "yoruba", |
| 238 | + "yo", |
| 239 | + ), |
| 240 | + ), |
| 241 | + SchemaOptional("max_top_tokens"): int, |
| 242 | + }, |
| 243 | + SchemaOptional("image_config"): { |
| 244 | + SchemaOptional("num_segments"): int, |
| 245 | + SchemaOptional("segment_compactness"): int, |
| 246 | + SchemaOptional("feature_extraction_method"): str, |
| 247 | + SchemaOptional("model_type"): str, |
| 248 | + SchemaOptional("max_objects"): int, |
| 249 | + SchemaOptional("iou_threshold"): float, |
| 250 | + SchemaOptional("context"): float, |
| 251 | + SchemaOptional("debug"): { |
| 252 | + SchemaOptional("image_names"): [str], |
| 253 | + SchemaOptional("class_ids"): [int], |
| 254 | + SchemaOptional("sample_from"): int, |
| 255 | + SchemaOptional("sample_to"): int, |
| 256 | + }, |
| 257 | + }, |
| 258 | + SchemaOptional("seed"): int, |
| 259 | + }, |
| 260 | + SchemaOptional("pre_training_bias"): {"methods": Or(str, [str])}, |
| 261 | + SchemaOptional("post_training_bias"): {"methods": Or(str, [str])}, |
| 262 | + SchemaOptional("pdp"): { |
| 263 | + "grid_resolution": int, |
| 264 | + SchemaOptional("features"): [Or(str, int)], |
| 265 | + SchemaOptional("top_k_features"): int, |
| 266 | + }, |
| 267 | + SchemaOptional("report"): {"name": str, SchemaOptional("title"): str}, |
| 268 | + }, |
| 269 | + SchemaOptional("predictor"): { |
| 270 | + SchemaOptional("endpoint_name"): str, |
| 271 | + SchemaOptional("endpoint_name_prefix"): And(str, Regex(ENDPOINT_NAME_PREFIX_PATTERN)), |
| 272 | + SchemaOptional("model_name"): str, |
| 273 | + SchemaOptional("target_model"): str, |
| 274 | + SchemaOptional("instance_type"): str, |
| 275 | + SchemaOptional("initial_instance_count"): int, |
| 276 | + SchemaOptional("accelerator_type"): str, |
| 277 | + SchemaOptional("content_type"): And( |
| 278 | + str, |
| 279 | + Use(str.lower), |
| 280 | + lambda s: s |
| 281 | + in ( |
| 282 | + "text/csv", |
| 283 | + "application/jsonlines", |
| 284 | + "image/jpeg", |
| 285 | + "image/jpg", |
| 286 | + "image/png", |
| 287 | + "application/x-npy", |
| 288 | + ), |
| 289 | + ), |
| 290 | + SchemaOptional("accept_type"): And( |
| 291 | + str, |
| 292 | + Use(str.lower), |
| 293 | + lambda s: s in ("text/csv", "application/jsonlines", "application/json"), |
| 294 | + ), |
| 295 | + SchemaOptional("label"): Or(str, int), |
| 296 | + SchemaOptional("probability"): Or(str, int), |
| 297 | + SchemaOptional("label_headers"): [Or(str, int)], |
| 298 | + SchemaOptional("content_template"): Or(str, {str: str}), |
| 299 | + SchemaOptional("custom_attributes"): str, |
| 300 | + }, |
| 301 | + } |
| 302 | +) |
| 303 | + |
| 304 | + |
38 | 305 | class DataConfig:
|
39 | 306 | """Config object related to configurations of the input and output dataset."""
|
40 | 307 |
|
@@ -926,6 +1193,7 @@ def __init__(
|
926 | 1193 | network_config: Optional[NetworkConfig] = None,
|
927 | 1194 | job_name_prefix: Optional[str] = None,
|
928 | 1195 | version: Optional[str] = None,
|
| 1196 | + skip_early_validation: bool = False, |
929 | 1197 | ):
|
930 | 1198 | """Initializes a SageMakerClarifyProcessor to compute bias metrics and model explanations.
|
931 | 1199 |
|
@@ -967,10 +1235,12 @@ def __init__(
|
967 | 1235 | inter-container traffic, security group IDs, and subnets.
|
968 | 1236 | job_name_prefix (str): Processing job name prefix.
|
969 | 1237 | version (str): Clarify version to use.
|
| 1238 | + skip_early_validation (bool): To skip schema validation of the generated analysis_schema.json. |
970 | 1239 | """ # noqa E501 # pylint: disable=c0301
|
971 | 1240 | container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version)
|
972 | 1241 | self._last_analysis_config = None
|
973 | 1242 | self.job_name_prefix = job_name_prefix
|
| 1243 | + self.skip_early_validation = skip_early_validation |
974 | 1244 | super(SageMakerClarifyProcessor, self).__init__(
|
975 | 1245 | role,
|
976 | 1246 | container_uri,
|
@@ -1034,6 +1304,8 @@ def _run(
|
1034 | 1304 | # for debugging: to access locally, i.e. without a need to look for it in an S3 bucket
|
1035 | 1305 | self._last_analysis_config = analysis_config
|
1036 | 1306 | logger.info("Analysis Config: %s", analysis_config)
|
| 1307 | + if not self.skip_early_validation: |
| 1308 | + ANALYSIS_CONFIG_SCHEMA_V1_0.validate(analysis_config) |
1037 | 1309 |
|
1038 | 1310 | with tempfile.TemporaryDirectory() as tmpdirname:
|
1039 | 1311 | analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")
|
|
0 commit comments