|
27 | 27 | from abc import ABC, abstractmethod
|
28 | 28 | from typing import List, Union, Dict
|
29 | 29 |
|
| 30 | +from schema import Schema, And, Use, Or, Optional, Regex |
| 31 | + |
30 | 32 | from sagemaker import image_uris, s3, utils
|
31 | 33 | from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
|
32 | 34 |
|
# Module-level logger named after this module, per the standard logging convention.
logger = logging.getLogger(__name__)
|
34 | 36 |
|
35 | 37 |
|
# Regex applied to the predictor "endpoint_name_prefix" field.
# NOTE(review): the pattern is unanchored at the end and requires at least two
# characters (the group has no quantifier), so a one-character prefix such as
# "a" is rejected — confirm against the SageMaker endpoint-name constraint
# before changing it; left as-is to preserve current validation behavior.
ENDPOINT_NAME_PREFIX_PATTERN = "^[a-zA-Z0-9](-*[a-zA-Z0-9])"

# Membership sets used by the schema below. The schema lower-cases each value
# (Use(str.lower)) before testing membership, so all entries are lower case.
# frozenset gives O(1) lookup and keeps the schema literal readable.

# Dataset MIME types accepted for "dataset_type".
_SUPPORTED_DATASET_TYPES = frozenset(
    (
        "text/csv",
        "application/jsonlines",
        "application/sagemakercapturejson",
        "application/x-parquet",
        "application/x-image",
    )
)

# SHAP aggregation methods accepted for "agg_method".
_SUPPORTED_AGG_METHODS = frozenset(("mean_abs", "median", "mean_sq"))

# Text explainability granularities accepted for "text_config.granularity".
_SUPPORTED_TEXT_GRANULARITIES = frozenset(("token", "sentence", "paragraph"))

# Languages accepted for "text_config.language": full names and their
# (mostly ISO 639-1) codes, interleaved as name/code pairs.
_SUPPORTED_TEXT_LANGUAGES = frozenset(
    (
        "chinese", "zh",
        "danish", "da",
        "dutch", "nl",
        "english", "en",
        "french", "fr",
        "german", "de",
        "greek", "el",
        "italian", "it",
        "japanese", "ja",
        "lithuanian", "lt",
        "multi-language", "xx",
        "norwegian bokmål", "nb",
        "polish", "pl",
        "portuguese", "pt",
        "romanian", "ro",
        "russian", "ru",
        "spanish", "es",
        "afrikaans", "af",
        "albanian", "sq",
        "arabic", "ar",
        "armenian", "hy",
        "basque", "eu",
        "bengali", "bn",
        "bulgarian", "bg",
        "catalan", "ca",
        "croatian", "hr",
        "czech", "cs",
        "estonian", "et",
        "finnish", "fi",
        "gujarati", "gu",
        "hebrew", "he",
        "hindi", "hi",
        "hungarian", "hu",
        "icelandic", "is",
        "indonesian", "id",
        "irish", "ga",
        "kannada", "kn",
        "kyrgyz", "ky",
        "latvian", "lv",
        "ligurian", "lij",
        "luxembourgish", "lb",
        "macedonian", "mk",
        "malayalam", "ml",
        "marathi", "mr",
        "nepali", "ne",
        "persian", "fa",
        "sanskrit", "sa",
        "serbian", "sr",
        "setswana", "tn",
        "sinhala", "si",
        "slovak", "sk",
        "slovenian", "sl",
        "swedish", "sv",
        "tagalog", "tl",
        "tamil", "ta",
        "tatar", "tt",
        "telugu", "te",
        "thai", "th",
        "turkish", "tr",
        "ukrainian", "uk",
        "urdu", "ur",
        "vietnamese", "vi",
        "yoruba", "yo",
    )
)

# Request content types accepted for the predictor "content_type".
_SUPPORTED_CONTENT_TYPES = frozenset(
    (
        "text/csv",
        "application/jsonlines",
        "image/jpeg",
        "image/jpg",
        "image/png",
        "application/x-npy",
    )
)

# Response content types accepted for the predictor "accept_type".
_SUPPORTED_ACCEPT_TYPES = frozenset(("text/csv", "application/jsonlines", "application/json"))

# Validation schema (v1.0) for the Clarify analysis config dictionary that is
# written to analysis_config.json. Validation happens before upload so that
# malformed configs fail fast on the client instead of inside the processing job.
ANALYSIS_CONFIG_SCHEMA_V1_0 = Schema(
    {
        Optional("version"): str,
        "dataset_type": And(str, Use(str.lower), lambda s: s in _SUPPORTED_DATASET_TYPES),
        Optional("dataset_uri"): str,
        Optional("headers"): [str],
        Optional("label"): Or(str, int),
        # This field indicates the user provides predicted_label in the dataset.
        Optional("predicted_label"): Or(str, int),
        Optional("features"): str,
        Optional("label_values_or_threshold"): [Or(int, float, str)],
        Optional("probability_threshold"): float,
        Optional("facet"): [
            {
                "name_or_index": Or(str, int),
                Optional("value_or_threshold"): [Or(int, float, str)],
            }
        ],
        Optional("facet_dataset_uri"): str,
        Optional("facet_headers"): [str],
        Optional("predicted_label_dataset_uri"): str,
        Optional("predicted_label_headers"): [str],
        Optional("excluded_columns"): [Or(int, str)],
        Optional("joinsource_name_or_index"): Or(str, int),
        Optional("group_variable"): Or(str, int),
        "methods": {
            Optional("shap"): {
                Optional("baseline"): Or(
                    # URI of the baseline data file.
                    str,
                    # In-place baseline data (a list of rows).
                    [
                        Or(
                            # CSV row.
                            [Or(int, float, str, None)],
                            # JSON row (any JSON object). As of writing only SageMaker
                            # JSONLines Dense Format ([1]) is supported, and that
                            # validation is NOT done by the schema but by the data loader.
                            # [1] https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-inference.html#cm-jsonlines
                            {object: object},
                        )
                    ],
                ),
                Optional("num_clusters"): int,
                Optional("use_logit"): bool,
                Optional("num_samples"): int,
                Optional("agg_method"): And(str, Use(str.lower), lambda s: s in _SUPPORTED_AGG_METHODS),
                Optional("save_local_shap_values"): bool,
                Optional("text_config"): {
                    "granularity": And(str, Use(str.lower), lambda s: s in _SUPPORTED_TEXT_GRANULARITIES),
                    "language": And(str, Use(str.lower), lambda s: s in _SUPPORTED_TEXT_LANGUAGES),
                    Optional("max_top_tokens"): int,
                },
                Optional("image_config"): {
                    Optional("num_segments"): int,
                    Optional("segment_compactness"): int,
                    Optional("feature_extraction_method"): str,
                    Optional("model_type"): str,
                    Optional("max_objects"): int,
                    Optional("iou_threshold"): float,
                    Optional("context"): float,
                    Optional("debug"): {
                        Optional("image_names"): [str],
                        Optional("class_ids"): [int],
                        Optional("sample_from"): int,
                        Optional("sample_to"): int,
                    },
                },
                Optional("seed"): int,
            },
            Optional("pre_training_bias"): {"methods": Or(str, [str])},
            Optional("post_training_bias"): {"methods": Or(str, [str])},
            Optional("pdp"): {
                "grid_resolution": int,
                Optional("features"): [Or(str, int)],
                Optional("top_k_features"): int,
            },
            Optional("report"): {"name": str, Optional("title"): str},
        },
        Optional("predictor"): {
            Optional("endpoint_name"): str,
            Optional("endpoint_name_prefix"): And(str, Regex(ENDPOINT_NAME_PREFIX_PATTERN)),
            Optional("model_name"): str,
            Optional("target_model"): str,
            Optional("instance_type"): str,
            Optional("initial_instance_count"): int,
            Optional("accelerator_type"): str,
            Optional("content_type"): And(str, Use(str.lower), lambda s: s in _SUPPORTED_CONTENT_TYPES),
            Optional("accept_type"): And(str, Use(str.lower), lambda s: s in _SUPPORTED_ACCEPT_TYPES),
            Optional("label"): Or(str, int),
            Optional("probability"): Or(str, int),
            Optional("label_headers"): [Or(str, int)],
            Optional("content_template"): Or(str, {str: str}),
            Optional("custom_attributes"): str,
        },
        Optional("local_predictor"): {"python_module": str, "class": str, Optional("args"): [str]},
    }
)
| 285 | + |
| 286 | + |
36 | 287 | class DataConfig:
|
37 | 288 | """Config object related to configurations of the input and output dataset."""
|
38 | 289 |
|
@@ -1030,6 +1281,7 @@ def _run(
|
1030 | 1281 | # for debugging: to access locally, i.e. without a need to look for it in an S3 bucket
|
1031 | 1282 | self._last_analysis_config = analysis_config
|
1032 | 1283 | logger.info("Analysis Config: %s", analysis_config)
|
| 1284 | + ANALYSIS_CONFIG_SCHEMA_V1_0.validate(analysis_config) |
1033 | 1285 |
|
1034 | 1286 | with tempfile.TemporaryDirectory() as tmpdirname:
|
1035 | 1287 | analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")
|
|
0 commit comments