Skip to content

Commit 9533b14

Browse files
authored
Merge branch 'master' into deprecate_lambda_model
2 parents 788a199 + e5e0408 commit 9533b14

File tree

5 files changed

+74
-16
lines changed

5 files changed

+74
-16
lines changed

src/sagemaker/clarify.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -305,12 +305,13 @@ class SHAPConfig(ExplainabilityConfig):
305305

306306
def __init__(
307307
self,
308-
baseline,
309-
num_samples,
310-
agg_method,
308+
baseline=None,
309+
num_samples=None,
310+
agg_method=None,
311311
use_logit=False,
312312
save_local_shap_values=True,
313313
seed=None,
314+
num_clusters=None,
314315
):
315316
"""Initializes config for SHAP.
316317
@@ -320,34 +321,49 @@ def __init__(
320321
be the same as the dataset format. Each row should contain only the feature
321322
columns/values and omit the label column/values. If None a baseline will be
322323
calculated automatically by using K-means or K-prototypes in the input dataset.
323-
num_samples (int): Number of samples to be used in the Kernel SHAP algorithm.
324+
num_samples (None or int): Number of samples to be used in the Kernel SHAP algorithm.
324325
This number determines the size of the generated synthetic dataset to compute the
325-
SHAP values.
326-
agg_method (str): Aggregation method for global SHAP values. Valid values are
326+
SHAP values. If not provided then Clarify job will choose a proper value according
327+
to the count of features.
328+
agg_method (None or str): Aggregation method for global SHAP values. Valid values are
327329
"mean_abs" (mean of absolute SHAP values for all instances),
328330
"median" (median of SHAP values for all instances) and
329331
"mean_sq" (mean of squared SHAP values for all instances).
332+
If not provided then Clarify job uses method "mean_abs"
330333
use_logit (bool): Indicator of whether the logit function is to be applied to the model
331334
predictions. Default is False. If "use_logit" is true then the SHAP values will
332335
have log-odds units.
333336
save_local_shap_values (bool): Indicator of whether to save the local SHAP values
334337
in the output location. Default is True.
335338
seed (int): seed value to get deterministic SHAP values. Default is None.
339+
num_clusters (None or int): If a baseline is not provided, Clarify automatically
340+
computes a baseline dataset via a clustering algorithm (K-means/K-prototypes).
341+
num_clusters is a parameter for this algorithm. num_clusters will be the resulting
342+
size of the baseline dataset. If not provided, Clarify job will use a default value.
336343
"""
337-
if agg_method not in ["mean_abs", "median", "mean_sq"]:
344+
if agg_method is not None and agg_method not in ["mean_abs", "median", "mean_sq"]:
338345
raise ValueError(
339346
f"Invalid agg_method {agg_method}." f" Please choose mean_abs, median, or mean_sq."
340347
)
341-
348+
if num_clusters is not None and baseline is not None:
349+
raise ValueError(
350+
"Baseline and num_clusters cannot be provided together. "
351+
"Please specify one of the two."
352+
)
342353
self.shap_config = {
343-
"baseline": baseline,
344-
"num_samples": num_samples,
345-
"agg_method": agg_method,
346354
"use_logit": use_logit,
347355
"save_local_shap_values": save_local_shap_values,
348356
}
357+
if baseline is not None:
358+
self.shap_config["baseline"] = baseline
359+
if num_samples is not None:
360+
self.shap_config["num_samples"] = num_samples
361+
if agg_method is not None:
362+
self.shap_config["agg_method"] = agg_method
349363
if seed is not None:
350364
self.shap_config["seed"] = seed
365+
if num_clusters is not None:
366+
self.shap_config["num_clusters"] = num_clusters
351367

352368
def get_explainability_config(self):
353369
"""Returns config."""

tests/integ/test_clarify_model_monitor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
SHAP_NUM_OF_SAMPLES = 5
6565
SHAP_AGG_METHOD = "mean_abs"
6666

67-
CRON = "cron(*/5 * * * ? *)"
67+
CRON = "cron(0 * * * ? *)"
6868
UPDATED_CRON = CronExpressionGenerator.daily()
6969
MAX_RUNTIME_IN_SECONDS = 30 * 60
7070
UPDATED_MAX_RUNTIME_IN_SECONDS = 25 * 60

tests/integ/test_model_monitor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@
8181

8282
INTEG_TEST_MONITORING_OUTPUT_BUCKET = "integ-test-monitoring-output-bucket"
8383

84-
FIVE_MIN_CRON_EXPRESSION = "cron(0/5 * ? * * *)"
84+
HOURLY_CRON_EXPRESSION = "cron(0 * ? * * *)"
8585

8686

8787
@pytest.fixture(scope="module")
@@ -151,7 +151,7 @@ def default_monitoring_schedule_name(sagemaker_session, output_kms_key, volume_k
151151
output_s3_uri=output_s3_uri,
152152
statistics=statistics,
153153
constraints=constraints,
154-
schedule_cron_expression=FIVE_MIN_CRON_EXPRESSION,
154+
schedule_cron_expression=HOURLY_CRON_EXPRESSION,
155155
enable_cloudwatch_metrics=ENABLE_CLOUDWATCH_METRICS,
156156
)
157157

@@ -211,7 +211,7 @@ def byoc_monitoring_schedule_name(sagemaker_session, output_kms_key, volume_kms_
211211
output=MonitoringOutput(source="/opt/ml/processing/output", destination=output_s3_uri),
212212
statistics=statistics,
213213
constraints=constraints,
214-
schedule_cron_expression=FIVE_MIN_CRON_EXPRESSION,
214+
schedule_cron_expression=HOURLY_CRON_EXPRESSION,
215215
)
216216

217217
_wait_for_schedule_changes_to_apply(monitor=my_byoc_monitor)

tests/integ/test_model_quality_monitor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
HEADERS_OF_FEATURES = ["F1", "F2", "F3", "F4", "F5", "F6", "F7"]
5050
ALL_HEADERS = [*HEADERS_OF_FEATURES, HEADER_OF_LABEL, HEADER_OF_PREDICTED_LABEL]
5151

52-
CRON = "cron(*/5 * * * ? *)"
52+
CRON = "cron(0 * * * ? *)"
5353
UPDATED_CRON = CronExpressionGenerator.daily()
5454
MAX_RUNTIME_IN_SECONDS = 30 * 60
5555
UPDATED_MAX_RUNTIME_IN_SECONDS = 25 * 60

tests/unit/test_clarify.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,42 @@ def test_shap_config():
268268
assert expected_config == shap_config.get_explainability_config()
269269

270270

271+
def test_shap_config_no_baseline():
272+
num_samples = 100
273+
agg_method = "mean_sq"
274+
use_logit = True
275+
seed = 123
276+
shap_config = SHAPConfig(
277+
num_samples=num_samples,
278+
agg_method=agg_method,
279+
num_clusters=2,
280+
use_logit=use_logit,
281+
seed=seed,
282+
)
283+
expected_config = {
284+
"shap": {
285+
"num_samples": num_samples,
286+
"agg_method": agg_method,
287+
"num_clusters": 2,
288+
"use_logit": use_logit,
289+
"save_local_shap_values": True,
290+
"seed": seed,
291+
}
292+
}
293+
assert expected_config == shap_config.get_explainability_config()
294+
295+
296+
def test_shap_config_no_parameters():
297+
shap_config = SHAPConfig()
298+
expected_config = {
299+
"shap": {
300+
"use_logit": False,
301+
"save_local_shap_values": True,
302+
}
303+
}
304+
assert expected_config == shap_config.get_explainability_config()
305+
306+
271307
def test_invalid_shap_config():
272308
with pytest.raises(ValueError) as error:
273309
SHAPConfig(
@@ -278,6 +314,12 @@ def test_invalid_shap_config():
278314
assert "Invalid agg_method invalid. Please choose mean_abs, median, or mean_sq." in str(
279315
error.value
280316
)
317+
with pytest.raises(ValueError) as error:
318+
SHAPConfig(baseline=[[1]], num_samples=1, agg_method="mean_abs", num_clusters=2)
319+
assert (
320+
"Baseline and num_clusters cannot be provided together. Please specify one of the two."
321+
in str(error.value)
322+
)
281323

282324

283325
@pytest.fixture(scope="module")

0 commit comments

Comments
 (0)