Skip to content

Commit 259d31d

Browse files
committed
Redo test
1 parent b90c856 commit 259d31d

File tree

1 file changed

+66
-9
lines changed

1 file changed

+66
-9
lines changed

tests/unit/test_clarify.py

+66-9
Original file line numberDiff line numberDiff line change
@@ -442,21 +442,22 @@ def test_post_training_bias(
442442
)
443443

444444

445-
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
446-
def test_shap(
445+
def _run_test_shap(
447446
name_from_base,
448447
clarify_processor,
449448
clarify_processor_with_job_name_prefix,
450449
data_config,
451450
model_config,
452451
shap_config,
452+
model_scores,
453+
expected_predictor_config,
453454
):
454455
with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method:
455456
clarify_processor.run_explainability(
456457
data_config,
457458
model_config,
458459
shap_config,
459-
model_scores=None,
460+
model_scores=model_scores,
460461
wait=True,
461462
job_name="test",
462463
experiment_config={"ExperimentName": "AnExperiment"},
@@ -485,11 +486,7 @@ def test_shap(
485486
"save_local_shap_values": True,
486487
}
487488
},
488-
"predictor": {
489-
"model_name": "xgboost-model",
490-
"instance_type": "ml.c5.xlarge",
491-
"initial_instance_count": 1,
492-
},
489+
"predictor": expected_predictor_config,
493490
}
494491
mock_method.assert_called_with(
495492
data_config,
@@ -504,7 +501,7 @@ def test_shap(
504501
data_config,
505502
model_config,
506503
shap_config,
507-
model_scores=None,
504+
model_scores=model_scores,
508505
wait=True,
509506
experiment_config={"ExperimentName": "AnExperiment"},
510507
)
@@ -518,3 +515,63 @@ def test_shap(
518515
None,
519516
{"ExperimentName": "AnExperiment"},
520517
)
518+
519+
520+
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
521+
def test_shap(
522+
name_from_base,
523+
clarify_processor,
524+
clarify_processor_with_job_name_prefix,
525+
data_config,
526+
model_config,
527+
shap_config,
528+
):
529+
expected_predictor_config = {
530+
"model_name": "xgboost-model",
531+
"instance_type": "ml.c5.xlarge",
532+
"initial_instance_count": 1,
533+
}
534+
_run_test_shap(
535+
name_from_base,
536+
clarify_processor,
537+
clarify_processor_with_job_name_prefix,
538+
data_config,
539+
model_config,
540+
shap_config,
541+
None,
542+
expected_predictor_config,
543+
)
544+
545+
546+
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
547+
def test_shap_with_predicted_label(
548+
name_from_base,
549+
clarify_processor,
550+
clarify_processor_with_job_name_prefix,
551+
data_config,
552+
model_config,
553+
shap_config,
554+
):
555+
probability = "pr"
556+
label_headers = ["success"]
557+
model_scores = ModelPredictedLabelConfig(
558+
probability=probability,
559+
label_headers=label_headers,
560+
)
561+
expected_predictor_config = {
562+
"model_name": "xgboost-model",
563+
"instance_type": "ml.c5.xlarge",
564+
"initial_instance_count": 1,
565+
"probability": probability,
566+
"label_headers": label_headers,
567+
}
568+
_run_test_shap(
569+
name_from_base,
570+
clarify_processor,
571+
clarify_processor_with_job_name_prefix,
572+
data_config,
573+
model_config,
574+
shap_config,
575+
model_scores,
576+
expected_predictor_config,
577+
)

0 commit comments

Comments
 (0)