Skip to content

Commit 37050d5

Browse files
committed
feat: added ANALYSIS_CONFIG_SCHEMA_V1_0 in clarify
1 parent d9559c7 commit 37050d5

File tree

3 files changed

+272
-14
lines changed

3 files changed

+272
-14
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def read_requirements(filename):
5858
"packaging>=20.0",
5959
"pandas",
6060
"pathos",
61+
"schema",
6162
]
6263

6364
# Specific use case dependencies

src/sagemaker/clarify.py

Lines changed: 270 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,264 @@
2727
from abc import ABC, abstractmethod
2828
from typing import List, Union, Dict
2929

30-
from sagemaker import image_uris, s3, utils
30+
from schema import Schema, And, Use, Or, Optional, Regex
31+
32+
from sagemaker import image_uris, s3, utils, Session
33+
from sagemaker.network import NetworkConfig
3134
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
3235

3336
logger = logging.getLogger(__name__)
3437

3538

39+
ENDPOINT_NAME_PREFIX_PATTERN = "^[a-zA-Z0-9](-*[a-zA-Z0-9])"
40+
41+
42+
ANALYSIS_CONFIG_SCHEMA_V1_0 = Schema(
43+
{
44+
Optional("version"): str,
45+
"dataset_type": And(
46+
str,
47+
Use(str.lower),
48+
lambda s: s
49+
in (
50+
"text/csv",
51+
"application/jsonlines",
52+
"application/sagemakercapturejson",
53+
"application/x-parquet",
54+
"application/x-image",
55+
),
56+
),
57+
Optional("dataset_uri"): str,
58+
Optional("headers"): [str],
59+
Optional("label"): Or(str, int),
60+
# this field indicates user provides predicted_label in dataset
61+
Optional("predicted_label"): Or(str, int),
62+
Optional("features"): str,
63+
Optional("label_values_or_threshold"): [Or(int, float, str)],
64+
Optional("probability_threshold"): float,
65+
Optional("facet"): [{"name_or_index": Or(str, int), Optional("value_or_threshold"): [Or(int, float, str)]}],
66+
Optional("facet_dataset_uri"): str,
67+
Optional("facet_headers"): [str],
68+
Optional("predicted_label_dataset_uri"): str,
69+
Optional("predicted_label_headers"): [str],
70+
Optional("excluded_columns"): [Or(int, str)],
71+
Optional("joinsource_name_or_index"): Or(str, int),
72+
Optional("group_variable"): Or(str, int),
73+
"methods": {
74+
Optional("shap"): {
75+
Optional("baseline"): Or(
76+
# URI of the baseline data file
77+
str,
78+
# Inplace baseline data (a list of something)
79+
[
80+
Or(
81+
# CSV row
82+
[Or(int, float, str, None)],
83+
# JSON row (any JSON object). As I write this only SageMaker JSONLines Dense Format ([1])
84+
# is supported and the validation is NOT done by the schema but by the data loader.
85+
# [1] https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-inference.html#cm-jsonlines
86+
{object: object},
87+
)
88+
],
89+
),
90+
Optional("num_clusters"): int,
91+
Optional("use_logit"): bool,
92+
Optional("num_samples"): int,
93+
Optional("agg_method"): And(str, Use(str.lower), lambda s: s in ("mean_abs", "median", "mean_sq")),
94+
Optional("save_local_shap_values"): bool,
95+
Optional("text_config"): {
96+
"granularity": And(str, Use(str.lower), lambda s: s in ("token", "sentence", "paragraph")),
97+
"language": And(
98+
str,
99+
Use(str.lower),
100+
lambda s: s
101+
in (
102+
"chinese",
103+
"zh",
104+
"danish",
105+
"da",
106+
"dutch",
107+
"nl",
108+
"english",
109+
"en",
110+
"french",
111+
"fr",
112+
"german",
113+
"de",
114+
"greek",
115+
"el",
116+
"italian",
117+
"it",
118+
"japanese",
119+
"ja",
120+
"lithuanian",
121+
"lt",
122+
"multi-language",
123+
"xx",
124+
"norwegian bokmål",
125+
"nb",
126+
"polish",
127+
"pl",
128+
"portuguese",
129+
"pt",
130+
"romanian",
131+
"ro",
132+
"russian",
133+
"ru",
134+
"spanish",
135+
"es",
136+
"afrikaans",
137+
"af",
138+
"albanian",
139+
"sq",
140+
"arabic",
141+
"ar",
142+
"armenian",
143+
"hy",
144+
"basque",
145+
"eu",
146+
"bengali",
147+
"bn",
148+
"bulgarian",
149+
"bg",
150+
"catalan",
151+
"ca",
152+
"croatian",
153+
"hr",
154+
"czech",
155+
"cs",
156+
"estonian",
157+
"et",
158+
"finnish",
159+
"fi",
160+
"gujarati",
161+
"gu",
162+
"hebrew",
163+
"he",
164+
"hindi",
165+
"hi",
166+
"hungarian",
167+
"hu",
168+
"icelandic",
169+
"is",
170+
"indonesian",
171+
"id",
172+
"irish",
173+
"ga",
174+
"kannada",
175+
"kn",
176+
"kyrgyz",
177+
"ky",
178+
"latvian",
179+
"lv",
180+
"ligurian",
181+
"lij",
182+
"luxembourgish",
183+
"lb",
184+
"macedonian",
185+
"mk",
186+
"malayalam",
187+
"ml",
188+
"marathi",
189+
"mr",
190+
"nepali",
191+
"ne",
192+
"persian",
193+
"fa",
194+
"sanskrit",
195+
"sa",
196+
"serbian",
197+
"sr",
198+
"setswana",
199+
"tn",
200+
"sinhala",
201+
"si",
202+
"slovak",
203+
"sk",
204+
"slovenian",
205+
"sl",
206+
"swedish",
207+
"sv",
208+
"tagalog",
209+
"tl",
210+
"tamil",
211+
"ta",
212+
"tatar",
213+
"tt",
214+
"telugu",
215+
"te",
216+
"thai",
217+
"th",
218+
"turkish",
219+
"tr",
220+
"ukrainian",
221+
"uk",
222+
"urdu",
223+
"ur",
224+
"vietnamese",
225+
"vi",
226+
"yoruba",
227+
"yo",
228+
),
229+
),
230+
Optional("max_top_tokens"): int,
231+
},
232+
Optional("image_config"): {
233+
Optional("num_segments"): int,
234+
Optional("segment_compactness"): int,
235+
Optional("feature_extraction_method"): str,
236+
Optional("model_type"): str,
237+
Optional("max_objects"): int,
238+
Optional("iou_threshold"): float,
239+
Optional("context"): float,
240+
Optional("debug"): {
241+
Optional("image_names"): [str],
242+
Optional("class_ids"): [int],
243+
Optional("sample_from"): int,
244+
Optional("sample_to"): int,
245+
},
246+
},
247+
Optional("seed"): int,
248+
},
249+
Optional("pre_training_bias"): {"methods": Or(str, [str])},
250+
Optional("post_training_bias"): {"methods": Or(str, [str])},
251+
Optional("pdp"): {
252+
"grid_resolution": int,
253+
Optional("features"): [Or(str, int)],
254+
Optional("top_k_features"): int,
255+
},
256+
Optional("report"): {"name": str, Optional("title"): str},
257+
},
258+
Optional("predictor"): {
259+
Optional("endpoint_name"): str,
260+
Optional("endpoint_name_prefix"): And(
261+
str, Regex(ENDPOINT_NAME_PREFIX_PATTERN)
262+
),
263+
Optional("model_name"): str,
264+
Optional("target_model"): str,
265+
Optional("instance_type"): str,
266+
Optional("initial_instance_count"): int,
267+
Optional("accelerator_type"): str,
268+
Optional("content_type"): And(
269+
str,
270+
Use(str.lower),
271+
lambda s: s
272+
in ("text/csv", "application/jsonlines", "image/jpeg", "image/jpg", "image/png", "application/x-npy"),
273+
),
274+
Optional("accept_type"): And(
275+
str, Use(str.lower), lambda s: s in ("text/csv", "application/jsonlines", "application/json")
276+
),
277+
Optional("label"): Or(str, int),
278+
Optional("probability"): Or(str, int),
279+
Optional("label_headers"): [Or(str, int)],
280+
Optional("content_template"): Or(str, {str: str}),
281+
Optional("custom_attributes"): str,
282+
},
283+
Optional("local_predictor"): {"python_module": str, "class": str, Optional("args"): [str]},
284+
}
285+
)
286+
287+
36288
class DataConfig:
37289
"""Config object related to configurations of the input and output dataset."""
38290

@@ -909,19 +1161,20 @@ class SageMakerClarifyProcessor(Processor):
9091161

9101162
def __init__(
9111163
self,
912-
role,
913-
instance_count,
914-
instance_type,
915-
volume_size_in_gb=30,
916-
volume_kms_key=None,
917-
output_kms_key=None,
918-
max_runtime_in_seconds=None,
919-
sagemaker_session=None,
920-
env=None,
921-
tags=None,
922-
network_config=None,
923-
job_name_prefix=None,
924-
version=None,
1164+
role: str,
1165+
instance_count: int,
1166+
instance_type: str,
1167+
volume_size_in_gb: int = 30,
1168+
volume_kms_key: str = None,
1169+
output_kms_key: str = None,
1170+
max_runtime_in_seconds: int = None,
1171+
sagemaker_session: Session = None,
1172+
env: Dict[str, str] = None,
1173+
tags: List[Dict[str, str]] = None,
1174+
network_config: NetworkConfig = None,
1175+
job_name_prefix: str = None,
1176+
version: str = None,
1177+
skip_early_validation: bool = False,
9251178
):
9261179
"""Initializes a SageMakerClarifyProcessor to compute bias metrics and model explanations.
9271180
@@ -967,6 +1220,7 @@ def __init__(
9671220
container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version)
9681221
self._last_analysis_config = None
9691222
self.job_name_prefix = job_name_prefix
1223+
self.skip_early_validation = skip_early_validation
9701224
super(SageMakerClarifyProcessor, self).__init__(
9711225
role,
9721226
container_uri,
@@ -1030,6 +1284,8 @@ def _run(
10301284
# for debugging: to access locally, i.e. without a need to look for it in an S3 bucket
10311285
self._last_analysis_config = analysis_config
10321286
logger.info("Analysis Config: %s", analysis_config)
1287+
if not self.skip_early_validation:
1288+
ANALYSIS_CONFIG_SCHEMA_V1_0.validate(analysis_config)
10331289

10341290
with tempfile.TemporaryDirectory() as tmpdirname:
10351291
analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")

tests/unit/test_clarify.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -914,6 +914,7 @@ def test_run_on_s3_analysis_config_file(
914914
processor_run, sagemaker_session, clarify_processor, data_config
915915
):
916916
analysis_config = {
917+
"dataset_type": "text/csv",
917918
"methods": {"post_training_bias": {"methods": "all"}},
918919
}
919920
with patch("sagemaker.clarify._upload_analysis_config", return_value=None) as mock_method:

0 commit comments

Comments
 (0)