Skip to content

Commit 4db22a0

Browse files
committed
feat: added ANALYSIS_CONFIG_SCHEMA_V1_0 in clarify
1 parent 736f503 commit 4db22a0

File tree

3 files changed

+286
-14
lines changed

3 files changed

+286
-14
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def read_requirements(filename):
5858
"packaging>=20.0",
5959
"pandas",
6060
"pathos",
61+
"schema",
6162
]
6263

6364
# Specific use case dependencies

src/sagemaker/clarify.py

Lines changed: 284 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,278 @@
2727
from abc import ABC, abstractmethod
2828
from typing import List, Union, Dict
2929

30-
from sagemaker import image_uris, s3, utils
30+
from schema import Schema, And, Use, Or, Optional, Regex
31+
32+
from sagemaker import image_uris, s3, utils, Session
33+
from sagemaker.network import NetworkConfig
3134
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
3235

3336
logger = logging.getLogger(__name__)
3437

3538

39+
ENDPOINT_NAME_PREFIX_PATTERN = "^[a-zA-Z0-9](-*[a-zA-Z0-9])"
40+
41+
42+
ANALYSIS_CONFIG_SCHEMA_V1_0 = Schema(
43+
{
44+
Optional("version"): str,
45+
"dataset_type": And(
46+
str,
47+
Use(str.lower),
48+
lambda s: s
49+
in (
50+
"text/csv",
51+
"application/jsonlines",
52+
"application/sagemakercapturejson",
53+
"application/x-parquet",
54+
"application/x-image",
55+
),
56+
),
57+
Optional("dataset_uri"): str,
58+
Optional("headers"): [str],
59+
Optional("label"): Or(str, int),
60+
# this field indicates that the user provides predicted_label in the dataset
61+
Optional("predicted_label"): Or(str, int),
62+
Optional("features"): str,
63+
Optional("label_values_or_threshold"): [Or(int, float, str)],
64+
Optional("probability_threshold"): float,
65+
Optional("facet"): [
66+
{"name_or_index": Or(str, int), Optional("value_or_threshold"): [Or(int, float, str)]}
67+
],
68+
Optional("facet_dataset_uri"): str,
69+
Optional("facet_headers"): [str],
70+
Optional("predicted_label_dataset_uri"): str,
71+
Optional("predicted_label_headers"): [str],
72+
Optional("excluded_columns"): [Or(int, str)],
73+
Optional("joinsource_name_or_index"): Or(str, int),
74+
Optional("group_variable"): Or(str, int),
75+
"methods": {
76+
Optional("shap"): {
77+
Optional("baseline"): Or(
78+
# URI of the baseline data file
79+
str,
80+
# In-place baseline data (a list of rows, each a CSV row or a JSON object)
81+
[
82+
Or(
83+
# CSV row
84+
[Or(int, float, str, None)],
85+
# JSON row (any JSON object). At the time of writing, only
86+
# SageMaker JSONLines Dense Format ([1])
87+
# is supported and the validation is NOT done
88+
# by the schema but by the data loader.
89+
# [1] https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-inference.html#cm-jsonlines
90+
{object: object},
91+
)
92+
],
93+
),
94+
Optional("num_clusters"): int,
95+
Optional("use_logit"): bool,
96+
Optional("num_samples"): int,
97+
Optional("agg_method"): And(
98+
str, Use(str.lower), lambda s: s in ("mean_abs", "median", "mean_sq")
99+
),
100+
Optional("save_local_shap_values"): bool,
101+
Optional("text_config"): {
102+
"granularity": And(
103+
str, Use(str.lower), lambda s: s in ("token", "sentence", "paragraph")
104+
),
105+
"language": And(
106+
str,
107+
Use(str.lower),
108+
lambda s: s
109+
in (
110+
"chinese",
111+
"zh",
112+
"danish",
113+
"da",
114+
"dutch",
115+
"nl",
116+
"english",
117+
"en",
118+
"french",
119+
"fr",
120+
"german",
121+
"de",
122+
"greek",
123+
"el",
124+
"italian",
125+
"it",
126+
"japanese",
127+
"ja",
128+
"lithuanian",
129+
"lt",
130+
"multi-language",
131+
"xx",
132+
"norwegian bokmål",
133+
"nb",
134+
"polish",
135+
"pl",
136+
"portuguese",
137+
"pt",
138+
"romanian",
139+
"ro",
140+
"russian",
141+
"ru",
142+
"spanish",
143+
"es",
144+
"afrikaans",
145+
"af",
146+
"albanian",
147+
"sq",
148+
"arabic",
149+
"ar",
150+
"armenian",
151+
"hy",
152+
"basque",
153+
"eu",
154+
"bengali",
155+
"bn",
156+
"bulgarian",
157+
"bg",
158+
"catalan",
159+
"ca",
160+
"croatian",
161+
"hr",
162+
"czech",
163+
"cs",
164+
"estonian",
165+
"et",
166+
"finnish",
167+
"fi",
168+
"gujarati",
169+
"gu",
170+
"hebrew",
171+
"he",
172+
"hindi",
173+
"hi",
174+
"hungarian",
175+
"hu",
176+
"icelandic",
177+
"is",
178+
"indonesian",
179+
"id",
180+
"irish",
181+
"ga",
182+
"kannada",
183+
"kn",
184+
"kyrgyz",
185+
"ky",
186+
"latvian",
187+
"lv",
188+
"ligurian",
189+
"lij",
190+
"luxembourgish",
191+
"lb",
192+
"macedonian",
193+
"mk",
194+
"malayalam",
195+
"ml",
196+
"marathi",
197+
"mr",
198+
"nepali",
199+
"ne",
200+
"persian",
201+
"fa",
202+
"sanskrit",
203+
"sa",
204+
"serbian",
205+
"sr",
206+
"setswana",
207+
"tn",
208+
"sinhala",
209+
"si",
210+
"slovak",
211+
"sk",
212+
"slovenian",
213+
"sl",
214+
"swedish",
215+
"sv",
216+
"tagalog",
217+
"tl",
218+
"tamil",
219+
"ta",
220+
"tatar",
221+
"tt",
222+
"telugu",
223+
"te",
224+
"thai",
225+
"th",
226+
"turkish",
227+
"tr",
228+
"ukrainian",
229+
"uk",
230+
"urdu",
231+
"ur",
232+
"vietnamese",
233+
"vi",
234+
"yoruba",
235+
"yo",
236+
),
237+
),
238+
Optional("max_top_tokens"): int,
239+
},
240+
Optional("image_config"): {
241+
Optional("num_segments"): int,
242+
Optional("segment_compactness"): int,
243+
Optional("feature_extraction_method"): str,
244+
Optional("model_type"): str,
245+
Optional("max_objects"): int,
246+
Optional("iou_threshold"): float,
247+
Optional("context"): float,
248+
Optional("debug"): {
249+
Optional("image_names"): [str],
250+
Optional("class_ids"): [int],
251+
Optional("sample_from"): int,
252+
Optional("sample_to"): int,
253+
},
254+
},
255+
Optional("seed"): int,
256+
},
257+
Optional("pre_training_bias"): {"methods": Or(str, [str])},
258+
Optional("post_training_bias"): {"methods": Or(str, [str])},
259+
Optional("pdp"): {
260+
"grid_resolution": int,
261+
Optional("features"): [Or(str, int)],
262+
Optional("top_k_features"): int,
263+
},
264+
Optional("report"): {"name": str, Optional("title"): str},
265+
},
266+
Optional("predictor"): {
267+
Optional("endpoint_name"): str,
268+
Optional("endpoint_name_prefix"): And(str, Regex(ENDPOINT_NAME_PREFIX_PATTERN)),
269+
Optional("model_name"): str,
270+
Optional("target_model"): str,
271+
Optional("instance_type"): str,
272+
Optional("initial_instance_count"): int,
273+
Optional("accelerator_type"): str,
274+
Optional("content_type"): And(
275+
str,
276+
Use(str.lower),
277+
lambda s: s
278+
in (
279+
"text/csv",
280+
"application/jsonlines",
281+
"image/jpeg",
282+
"image/jpg",
283+
"image/png",
284+
"application/x-npy",
285+
),
286+
),
287+
Optional("accept_type"): And(
288+
str,
289+
Use(str.lower),
290+
lambda s: s in ("text/csv", "application/jsonlines", "application/json"),
291+
),
292+
Optional("label"): Or(str, int),
293+
Optional("probability"): Or(str, int),
294+
Optional("label_headers"): [Or(str, int)],
295+
Optional("content_template"): Or(str, {str: str}),
296+
Optional("custom_attributes"): str,
297+
},
298+
}
299+
)
300+
301+
36302
class DataConfig:
37303
"""Config object related to configurations of the input and output dataset."""
38304

@@ -909,19 +1175,20 @@ class SageMakerClarifyProcessor(Processor):
9091175

9101176
def __init__(
9111177
self,
912-
role,
913-
instance_count,
914-
instance_type,
915-
volume_size_in_gb=30,
916-
volume_kms_key=None,
917-
output_kms_key=None,
918-
max_runtime_in_seconds=None,
919-
sagemaker_session=None,
920-
env=None,
921-
tags=None,
922-
network_config=None,
923-
job_name_prefix=None,
924-
version=None,
1178+
role: str,
1179+
instance_count: int,
1180+
instance_type: str,
1181+
volume_size_in_gb: int = 30,
1182+
volume_kms_key: str = None,
1183+
output_kms_key: str = None,
1184+
max_runtime_in_seconds: int = None,
1185+
sagemaker_session: Session = None,
1186+
env: Dict[str, str] = None,
1187+
tags: List[Dict[str, str]] = None,
1188+
network_config: NetworkConfig = None,
1189+
job_name_prefix: str = None,
1190+
version: str = None,
1191+
skip_early_validation: bool = False,
9251192
):
9261193
"""Initializes a SageMakerClarifyProcessor to compute bias metrics and model explanations.
9271194
@@ -967,6 +1234,7 @@ def __init__(
9671234
container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version)
9681235
self._last_analysis_config = None
9691236
self.job_name_prefix = job_name_prefix
1237+
self.skip_early_validation = skip_early_validation
9701238
super(SageMakerClarifyProcessor, self).__init__(
9711239
role,
9721240
container_uri,
@@ -1030,6 +1298,8 @@ def _run(
10301298
# for debugging: to access locally, i.e. without a need to look for it in an S3 bucket
10311299
self._last_analysis_config = analysis_config
10321300
logger.info("Analysis Config: %s", analysis_config)
1301+
if not self.skip_early_validation:
1302+
ANALYSIS_CONFIG_SCHEMA_V1_0.validate(analysis_config)
10331303

10341304
with tempfile.TemporaryDirectory() as tmpdirname:
10351305
analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")

tests/unit/test_clarify.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -914,6 +914,7 @@ def test_run_on_s3_analysis_config_file(
914914
processor_run, sagemaker_session, clarify_processor, data_config
915915
):
916916
analysis_config = {
917+
"dataset_type": "text/csv",
917918
"methods": {"post_training_bias": {"methods": "all"}},
918919
}
919920
with patch("sagemaker.clarify._upload_analysis_config", return_value=None) as mock_method:

0 commit comments

Comments
 (0)