Skip to content

Commit 2d122b7

Browse files
committed
feat: add ANALYSIS_CONFIG_SCHEMA_V1_0 to clarify
1 parent d9559c7 commit 2d122b7

File tree

3 files changed

+254
-0
lines changed

3 files changed

+254
-0
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def read_requirements(filename):
5858
"packaging>=20.0",
5959
"pandas",
6060
"pathos",
61+
"schema",
6162
]
6263

6364
# Specific use case dependencies

src/sagemaker/clarify.py

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,263 @@
2727
from abc import ABC, abstractmethod
2828
from typing import List, Union, Dict
2929

30+
from schema import Schema, And, Use, Or, Optional, Regex
31+
3032
from sagemaker import image_uris, s3, utils
3133
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
3234

3335
logger = logging.getLogger(__name__)
3436

3537

38+
# Regex for a valid SageMaker endpoint-name prefix: alphanumeric characters,
# optionally separated by runs of hyphens, starting with an alphanumeric.
# The trailing "*" on the group is required so that a single-character prefix
# (e.g. "a") is accepted; without it the pattern demands at least two
# alphanumeric characters.
ENDPOINT_NAME_PREFIX_PATTERN = "^[a-zA-Z0-9](-*[a-zA-Z0-9])*"
39+
40+
41+
# Schema (version 1.0) for the Clarify analysis configuration dictionary.
# It is applied via ANALYSIS_CONFIG_SCHEMA_V1_0.validate(analysis_config) in
# the processor's _run() before the config is written to disk, so a malformed
# config fails fast on the client instead of inside the processing job.
# Keys wrapped in Optional(...) may be omitted; bare string keys are required.
ANALYSIS_CONFIG_SCHEMA_V1_0 = Schema(
    {
        Optional("version"): str,
        # MIME type of the input dataset; lower-cased before membership check.
        "dataset_type": And(
            str,
            Use(str.lower),
            lambda s: s
            in (
                "text/csv",
                "application/jsonlines",
                "application/sagemakercapturejson",
                "application/x-parquet",
                "application/x-image",
            ),
        ),
        Optional("dataset_uri"): str,
        Optional("headers"): [str],
        # Label column: name (str) or index (int).
        Optional("label"): Or(str, int),
        # this field indicates user provides predicted_label in dataset
        Optional("predicted_label"): Or(str, int),
        Optional("features"): str,
        Optional("label_values_or_threshold"): [Or(int, float, str)],
        Optional("probability_threshold"): float,
        # Facets (sensitive attributes) for bias analysis; each entry names a
        # column by name or index, with optional value(s)/threshold(s).
        Optional("facet"): [{"name_or_index": Or(str, int), Optional("value_or_threshold"): [Or(int, float, str)]}],
        Optional("facet_dataset_uri"): str,
        Optional("facet_headers"): [str],
        Optional("predicted_label_dataset_uri"): str,
        Optional("predicted_label_headers"): [str],
        Optional("excluded_columns"): [Or(int, str)],
        Optional("joinsource_name_or_index"): Or(str, int),
        Optional("group_variable"): Or(str, int),
        # The analyses to run; "methods" itself is required, every method within
        # it is optional.
        "methods": {
            Optional("shap"): {
                Optional("baseline"): Or(
                    # URI of the baseline data file
                    str,
                    # Inplace baseline data (a list of something)
                    [
                        Or(
                            # CSV row
                            [Or(int, float, str, None)],
                            # JSON row (any JSON object). As I write this only SageMaker JSONLines Dense Format ([1])
                            # is supported and the validation is NOT done by the schema but by the data loader.
                            # [1] https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-inference.html#cm-jsonlines
                            {object: object},
                        )
                    ],
                ),
                Optional("num_clusters"): int,
                Optional("use_logit"): bool,
                Optional("num_samples"): int,
                Optional("agg_method"): And(str, Use(str.lower), lambda s: s in ("mean_abs", "median", "mean_sq")),
                Optional("save_local_shap_values"): bool,
                # NLP explainability settings; language accepts full names and
                # ISO-style short codes (presumably mirroring the spaCy model
                # set — confirm against the container's supported languages).
                Optional("text_config"): {
                    "granularity": And(str, Use(str.lower), lambda s: s in ("token", "sentence", "paragraph")),
                    "language": And(
                        str,
                        Use(str.lower),
                        lambda s: s
                        in (
                            "chinese",
                            "zh",
                            "danish",
                            "da",
                            "dutch",
                            "nl",
                            "english",
                            "en",
                            "french",
                            "fr",
                            "german",
                            "de",
                            "greek",
                            "el",
                            "italian",
                            "it",
                            "japanese",
                            "ja",
                            "lithuanian",
                            "lt",
                            "multi-language",
                            "xx",
                            "norwegian bokmål",
                            "nb",
                            "polish",
                            "pl",
                            "portuguese",
                            "pt",
                            "romanian",
                            "ro",
                            "russian",
                            "ru",
                            "spanish",
                            "es",
                            "afrikaans",
                            "af",
                            "albanian",
                            "sq",
                            "arabic",
                            "ar",
                            "armenian",
                            "hy",
                            "basque",
                            "eu",
                            "bengali",
                            "bn",
                            "bulgarian",
                            "bg",
                            "catalan",
                            "ca",
                            "croatian",
                            "hr",
                            "czech",
                            "cs",
                            "estonian",
                            "et",
                            "finnish",
                            "fi",
                            "gujarati",
                            "gu",
                            "hebrew",
                            "he",
                            "hindi",
                            "hi",
                            "hungarian",
                            "hu",
                            "icelandic",
                            "is",
                            "indonesian",
                            "id",
                            "irish",
                            "ga",
                            "kannada",
                            "kn",
                            "kyrgyz",
                            "ky",
                            "latvian",
                            "lv",
                            "ligurian",
                            "lij",
                            "luxembourgish",
                            "lb",
                            "macedonian",
                            "mk",
                            "malayalam",
                            "ml",
                            "marathi",
                            "mr",
                            "nepali",
                            "ne",
                            "persian",
                            "fa",
                            "sanskrit",
                            "sa",
                            "serbian",
                            "sr",
                            "setswana",
                            "tn",
                            "sinhala",
                            "si",
                            "slovak",
                            "sk",
                            "slovenian",
                            "sl",
                            "swedish",
                            "sv",
                            "tagalog",
                            "tl",
                            "tamil",
                            "ta",
                            "tatar",
                            "tt",
                            "telugu",
                            "te",
                            "thai",
                            "th",
                            "turkish",
                            "tr",
                            "ukrainian",
                            "uk",
                            "urdu",
                            "ur",
                            "vietnamese",
                            "vi",
                            "yoruba",
                            "yo",
                        ),
                    ),
                    Optional("max_top_tokens"): int,
                },
                # Computer-vision explainability settings.
                Optional("image_config"): {
                    Optional("num_segments"): int,
                    Optional("segment_compactness"): int,
                    Optional("feature_extraction_method"): str,
                    Optional("model_type"): str,
                    Optional("max_objects"): int,
                    Optional("iou_threshold"): float,
                    Optional("context"): float,
                    Optional("debug"): {
                        Optional("image_names"): [str],
                        Optional("class_ids"): [int],
                        Optional("sample_from"): int,
                        Optional("sample_to"): int,
                    },
                },
                Optional("seed"): int,
            },
            # Bias methods accept a single method name or a list of names
            # (e.g. "all"; exact method IDs are validated server-side, not here).
            Optional("pre_training_bias"): {"methods": Or(str, [str])},
            Optional("post_training_bias"): {"methods": Or(str, [str])},
            # Partial dependence plots; grid_resolution is required when pdp
            # is present.
            Optional("pdp"): {
                "grid_resolution": int,
                Optional("features"): [Or(str, int)],
                Optional("top_k_features"): int,
            },
            Optional("report"): {"name": str, Optional("title"): str},
        },
        # Settings for the shadow endpoint Clarify creates (or reuses) to get
        # model predictions during the analysis.
        Optional("predictor"): {
            Optional("endpoint_name"): str,
            Optional("endpoint_name_prefix"): And(
                str, Regex(ENDPOINT_NAME_PREFIX_PATTERN)
            ),
            Optional("model_name"): str,
            Optional("target_model"): str,
            Optional("instance_type"): str,
            Optional("initial_instance_count"): int,
            Optional("accelerator_type"): str,
            # Request/response MIME types; lower-cased before membership check.
            Optional("content_type"): And(
                str,
                Use(str.lower),
                lambda s: s
                in ("text/csv", "application/jsonlines", "image/jpeg", "image/jpg", "image/png", "application/x-npy"),
            ),
            Optional("accept_type"): And(
                str, Use(str.lower), lambda s: s in ("text/csv", "application/jsonlines", "application/json")
            ),
            Optional("label"): Or(str, int),
            Optional("probability"): Or(str, int),
            Optional("label_headers"): [Or(str, int)],
            Optional("content_template"): Or(str, {str: str}),
            Optional("custom_attributes"): str,
        },
        # In-process predictor alternative to a real endpoint: a Python module
        # and class to instantiate locally, with optional constructor args.
        Optional("local_predictor"): {"python_module": str, "class": str, Optional("args"): [str]},
    }
)
285+
286+
36287
class DataConfig:
37288
"""Config object related to configurations of the input and output dataset."""
38289

@@ -1030,6 +1281,7 @@ def _run(
10301281
# for debugging: to access locally, i.e. without a need to look for it in an S3 bucket
10311282
self._last_analysis_config = analysis_config
10321283
logger.info("Analysis Config: %s", analysis_config)
1284+
ANALYSIS_CONFIG_SCHEMA_V1_0.validate(analysis_config)
10331285

10341286
with tempfile.TemporaryDirectory() as tmpdirname:
10351287
analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")

tests/unit/test_clarify.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -914,6 +914,7 @@ def test_run_on_s3_analysis_config_file(
914914
processor_run, sagemaker_session, clarify_processor, data_config
915915
):
916916
analysis_config = {
917+
"dataset_type": "text/csv",
917918
"methods": {"post_training_bias": {"methods": "all"}},
918919
}
919920
with patch("sagemaker.clarify._upload_analysis_config", return_value=None) as mock_method:

0 commit comments

Comments
 (0)