Skip to content

Commit ebc0c2f

Browse files
authored
Merge branch 'master' into master
2 parents e098545 + b4ea839 commit ebc0c2f

File tree

6 files changed

+282
-1
lines changed

6 files changed

+282
-1
lines changed

setup.py

+1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def read_requirements(filename):
5858
"packaging>=20.0",
5959
"pandas",
6060
"pathos",
61+
"schema",
6162
]
6263

6364
# Specific use case dependencies

src/sagemaker/clarify.py

+272
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from abc import ABC, abstractmethod
2828
from typing import List, Union, Dict, Optional, Any
2929

30+
from schema import Schema, And, Use, Or, Optional as SchemaOptional, Regex
31+
3032
from sagemaker import image_uris, s3, utils
3133
from sagemaker.session import Session
3234
from sagemaker.network import NetworkConfig
@@ -35,6 +37,271 @@
3537
logger = logging.getLogger(__name__)
3638

3739

40+
ENDPOINT_NAME_PREFIX_PATTERN = "^[a-zA-Z0-9](-*[a-zA-Z0-9])"
41+
42+
ANALYSIS_CONFIG_SCHEMA_V1_0 = Schema(
43+
{
44+
SchemaOptional("version"): str,
45+
"dataset_type": And(
46+
str,
47+
Use(str.lower),
48+
lambda s: s
49+
in (
50+
"text/csv",
51+
"application/jsonlines",
52+
"application/sagemakercapturejson",
53+
"application/x-parquet",
54+
"application/x-image",
55+
),
56+
),
57+
SchemaOptional("dataset_uri"): str,
58+
SchemaOptional("headers"): [str],
59+
SchemaOptional("label"): Or(str, int),
60+
# this field indicates user provides predicted_label in dataset
61+
SchemaOptional("predicted_label"): Or(str, int),
62+
SchemaOptional("features"): str,
63+
SchemaOptional("label_values_or_threshold"): [Or(int, float, str)],
64+
SchemaOptional("probability_threshold"): float,
65+
SchemaOptional("facet"): [
66+
{
67+
"name_or_index": Or(str, int),
68+
SchemaOptional("value_or_threshold"): [Or(int, float, str)],
69+
}
70+
],
71+
SchemaOptional("facet_dataset_uri"): str,
72+
SchemaOptional("facet_headers"): [str],
73+
SchemaOptional("predicted_label_dataset_uri"): str,
74+
SchemaOptional("predicted_label_headers"): [str],
75+
SchemaOptional("excluded_columns"): [Or(int, str)],
76+
SchemaOptional("joinsource_name_or_index"): Or(str, int),
77+
SchemaOptional("group_variable"): Or(str, int),
78+
"methods": {
79+
SchemaOptional("shap"): {
80+
SchemaOptional("baseline"): Or(
81+
# URI of the baseline data file
82+
str,
83+
# Inplace baseline data (a list of something)
84+
[
85+
Or(
86+
# CSV row
87+
[Or(int, float, str, None)],
88+
# JSON row (any JSON object). As I write this only
89+
# SageMaker JSONLines Dense Format ([1])
90+
# is supported and the validation is NOT done
91+
# by the schema but by the data loader.
92+
# [1] https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-inference.html#cm-jsonlines
93+
{object: object},
94+
)
95+
],
96+
),
97+
SchemaOptional("num_clusters"): int,
98+
SchemaOptional("use_logit"): bool,
99+
SchemaOptional("num_samples"): int,
100+
SchemaOptional("agg_method"): And(
101+
str, Use(str.lower), lambda s: s in ("mean_abs", "median", "mean_sq")
102+
),
103+
SchemaOptional("save_local_shap_values"): bool,
104+
SchemaOptional("text_config"): {
105+
"granularity": And(
106+
str, Use(str.lower), lambda s: s in ("token", "sentence", "paragraph")
107+
),
108+
"language": And(
109+
str,
110+
Use(str.lower),
111+
lambda s: s
112+
in (
113+
"chinese",
114+
"zh",
115+
"danish",
116+
"da",
117+
"dutch",
118+
"nl",
119+
"english",
120+
"en",
121+
"french",
122+
"fr",
123+
"german",
124+
"de",
125+
"greek",
126+
"el",
127+
"italian",
128+
"it",
129+
"japanese",
130+
"ja",
131+
"lithuanian",
132+
"lt",
133+
"multi-language",
134+
"xx",
135+
"norwegian bokmål",
136+
"nb",
137+
"polish",
138+
"pl",
139+
"portuguese",
140+
"pt",
141+
"romanian",
142+
"ro",
143+
"russian",
144+
"ru",
145+
"spanish",
146+
"es",
147+
"afrikaans",
148+
"af",
149+
"albanian",
150+
"sq",
151+
"arabic",
152+
"ar",
153+
"armenian",
154+
"hy",
155+
"basque",
156+
"eu",
157+
"bengali",
158+
"bn",
159+
"bulgarian",
160+
"bg",
161+
"catalan",
162+
"ca",
163+
"croatian",
164+
"hr",
165+
"czech",
166+
"cs",
167+
"estonian",
168+
"et",
169+
"finnish",
170+
"fi",
171+
"gujarati",
172+
"gu",
173+
"hebrew",
174+
"he",
175+
"hindi",
176+
"hi",
177+
"hungarian",
178+
"hu",
179+
"icelandic",
180+
"is",
181+
"indonesian",
182+
"id",
183+
"irish",
184+
"ga",
185+
"kannada",
186+
"kn",
187+
"kyrgyz",
188+
"ky",
189+
"latvian",
190+
"lv",
191+
"ligurian",
192+
"lij",
193+
"luxembourgish",
194+
"lb",
195+
"macedonian",
196+
"mk",
197+
"malayalam",
198+
"ml",
199+
"marathi",
200+
"mr",
201+
"nepali",
202+
"ne",
203+
"persian",
204+
"fa",
205+
"sanskrit",
206+
"sa",
207+
"serbian",
208+
"sr",
209+
"setswana",
210+
"tn",
211+
"sinhala",
212+
"si",
213+
"slovak",
214+
"sk",
215+
"slovenian",
216+
"sl",
217+
"swedish",
218+
"sv",
219+
"tagalog",
220+
"tl",
221+
"tamil",
222+
"ta",
223+
"tatar",
224+
"tt",
225+
"telugu",
226+
"te",
227+
"thai",
228+
"th",
229+
"turkish",
230+
"tr",
231+
"ukrainian",
232+
"uk",
233+
"urdu",
234+
"ur",
235+
"vietnamese",
236+
"vi",
237+
"yoruba",
238+
"yo",
239+
),
240+
),
241+
SchemaOptional("max_top_tokens"): int,
242+
},
243+
SchemaOptional("image_config"): {
244+
SchemaOptional("num_segments"): int,
245+
SchemaOptional("segment_compactness"): int,
246+
SchemaOptional("feature_extraction_method"): str,
247+
SchemaOptional("model_type"): str,
248+
SchemaOptional("max_objects"): int,
249+
SchemaOptional("iou_threshold"): float,
250+
SchemaOptional("context"): float,
251+
SchemaOptional("debug"): {
252+
SchemaOptional("image_names"): [str],
253+
SchemaOptional("class_ids"): [int],
254+
SchemaOptional("sample_from"): int,
255+
SchemaOptional("sample_to"): int,
256+
},
257+
},
258+
SchemaOptional("seed"): int,
259+
},
260+
SchemaOptional("pre_training_bias"): {"methods": Or(str, [str])},
261+
SchemaOptional("post_training_bias"): {"methods": Or(str, [str])},
262+
SchemaOptional("pdp"): {
263+
"grid_resolution": int,
264+
SchemaOptional("features"): [Or(str, int)],
265+
SchemaOptional("top_k_features"): int,
266+
},
267+
SchemaOptional("report"): {"name": str, SchemaOptional("title"): str},
268+
},
269+
SchemaOptional("predictor"): {
270+
SchemaOptional("endpoint_name"): str,
271+
SchemaOptional("endpoint_name_prefix"): And(str, Regex(ENDPOINT_NAME_PREFIX_PATTERN)),
272+
SchemaOptional("model_name"): str,
273+
SchemaOptional("target_model"): str,
274+
SchemaOptional("instance_type"): str,
275+
SchemaOptional("initial_instance_count"): int,
276+
SchemaOptional("accelerator_type"): str,
277+
SchemaOptional("content_type"): And(
278+
str,
279+
Use(str.lower),
280+
lambda s: s
281+
in (
282+
"text/csv",
283+
"application/jsonlines",
284+
"image/jpeg",
285+
"image/jpg",
286+
"image/png",
287+
"application/x-npy",
288+
),
289+
),
290+
SchemaOptional("accept_type"): And(
291+
str,
292+
Use(str.lower),
293+
lambda s: s in ("text/csv", "application/jsonlines", "application/json"),
294+
),
295+
SchemaOptional("label"): Or(str, int),
296+
SchemaOptional("probability"): Or(str, int),
297+
SchemaOptional("label_headers"): [Or(str, int)],
298+
SchemaOptional("content_template"): Or(str, {str: str}),
299+
SchemaOptional("custom_attributes"): str,
300+
},
301+
}
302+
)
303+
304+
38305
class DataConfig:
39306
"""Config object related to configurations of the input and output dataset."""
40307

@@ -926,6 +1193,7 @@ def __init__(
9261193
network_config: Optional[NetworkConfig] = None,
9271194
job_name_prefix: Optional[str] = None,
9281195
version: Optional[str] = None,
1196+
skip_early_validation: bool = False,
9291197
):
9301198
"""Initializes a SageMakerClarifyProcessor to compute bias metrics and model explanations.
9311199
@@ -967,10 +1235,12 @@ def __init__(
9671235
inter-container traffic, security group IDs, and subnets.
9681236
job_name_prefix (str): Processing job name prefix.
9691237
version (str): Clarify version to use.
1238+
skip_early_validation (bool): To skip schema validation of the generated analysis_schema.json.
9701239
""" # noqa E501 # pylint: disable=c0301
9711240
container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version)
9721241
self._last_analysis_config = None
9731242
self.job_name_prefix = job_name_prefix
1243+
self.skip_early_validation = skip_early_validation
9741244
super(SageMakerClarifyProcessor, self).__init__(
9751245
role,
9761246
container_uri,
@@ -1034,6 +1304,8 @@ def _run(
10341304
# for debugging: to access locally, i.e. without a need to look for it in an S3 bucket
10351305
self._last_analysis_config = analysis_config
10361306
logger.info("Analysis Config: %s", analysis_config)
1307+
if not self.skip_early_validation:
1308+
ANALYSIS_CONFIG_SCHEMA_V1_0.validate(analysis_config)
10371309

10381310
with tempfile.TemporaryDirectory() as tmpdirname:
10391311
analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")

src/sagemaker/image_uri_config/model-monitor.json

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"ap-south-1": "126357580389",
1212
"ap-southeast-1": "245545462676",
1313
"ap-southeast-2": "563025443158",
14+
"ap-southeast-3": "669540362728",
1415
"ca-central-1": "536280801234",
1516
"cn-north-1": "453000072557",
1617
"cn-northwest-1": "453252182341",

src/sagemaker/session.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -4999,7 +4999,12 @@ def _rule_statuses_changed(current_statuses, last_statuses):
49994999
def _logs_init(sagemaker_session, description, job):
50005000
"""Placeholder docstring"""
50015001
if job == "Training":
5002-
instance_count = description["ResourceConfig"]["InstanceCount"]
5002+
if description["ResourceConfig"]["InstanceCount"] is not None:
5003+
instance_count = description["ResourceConfig"]["InstanceCount"]
5004+
else:
5005+
instance_count = 0
5006+
for instanceGroup in description["ResourceConfig"]["InstanceGroups"]:
5007+
instance_count += instanceGroup["InstanceCount"]
50035008
elif job == "Transform":
50045009
instance_count = description["TransformResources"]["InstanceCount"]
50055010
elif job == "Processing":

tests/unit/sagemaker/image_uris/test_model_monitor.py

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
"ap-south-1": "126357580389",
2525
"ap-southeast-1": "245545462676",
2626
"ap-southeast-2": "563025443158",
27+
"ap-southeast-3": "669540362728",
2728
"ca-central-1": "536280801234",
2829
"cn-north-1": "453000072557",
2930
"cn-northwest-1": "453252182341",

tests/unit/test_clarify.py

+1
Original file line numberDiff line numberDiff line change
@@ -914,6 +914,7 @@ def test_run_on_s3_analysis_config_file(
914914
processor_run, sagemaker_session, clarify_processor, data_config
915915
):
916916
analysis_config = {
917+
"dataset_type": "text/csv",
917918
"methods": {"post_training_bias": {"methods": "all"}},
918919
}
919920
with patch("sagemaker.clarify._upload_analysis_config", return_value=None) as mock_method:

0 commit comments

Comments
 (0)