@@ -379,13 +379,20 @@ def test_post_training_bias(
379
379
)
380
380
381
381
382
- def test_shap (clarify_processor , data_config , model_config , shap_config ):
382
+ def _run_test_shap (
383
+ clarify_processor ,
384
+ data_config ,
385
+ model_config ,
386
+ shap_config ,
387
+ model_scores ,
388
+ expected_predictor_config ,
389
+ ):
383
390
with patch .object (SageMakerClarifyProcessor , "_run" , return_value = None ) as mock_method :
384
391
clarify_processor .run_explainability (
385
392
data_config ,
386
393
model_config ,
387
394
shap_config ,
388
- model_scores = None ,
395
+ model_scores = model_scores ,
389
396
wait = True ,
390
397
job_name = "test" ,
391
398
experiment_config = {"ExperimentName" : "AnExperiment" },
@@ -414,11 +421,7 @@ def test_shap(clarify_processor, data_config, model_config, shap_config):
414
421
"save_local_shap_values" : True ,
415
422
}
416
423
},
417
- "predictor" : {
418
- "model_name" : "xgboost-model" ,
419
- "instance_type" : "ml.c5.xlarge" ,
420
- "initial_instance_count" : 1 ,
421
- },
424
+ "predictor" : expected_predictor_config ,
422
425
}
423
426
mock_method .assert_called_once_with (
424
427
data_config ,
@@ -429,3 +432,44 @@ def test_shap(clarify_processor, data_config, model_config, shap_config):
429
432
None ,
430
433
{"ExperimentName" : "AnExperiment" },
431
434
)
435
+
436
+
437
+ def test_shap (clarify_processor , data_config , model_config , shap_config ):
438
+ model_scores = None
439
+ expected_predictor_config = {
440
+ "model_name" : "xgboost-model" ,
441
+ "instance_type" : "ml.c5.xlarge" ,
442
+ "initial_instance_count" : 1 ,
443
+ }
444
+ _run_test_shap (
445
+ clarify_processor ,
446
+ data_config ,
447
+ model_config ,
448
+ shap_config ,
449
+ model_scores ,
450
+ expected_predictor_config ,
451
+ )
452
+
453
+
454
+ def test_shap_with_predicted_label (clarify_processor , data_config , model_config , shap_config ):
455
+ probability = "pr"
456
+ label_headers = ["success" ]
457
+ model_scores = ModelPredictedLabelConfig (
458
+ probability = probability ,
459
+ label_headers = label_headers ,
460
+ )
461
+ expected_predictor_config = {
462
+ "model_name" : "xgboost-model" ,
463
+ "instance_type" : "ml.c5.xlarge" ,
464
+ "initial_instance_count" : 1 ,
465
+ "probability" : probability ,
466
+ "label_headers" : label_headers ,
467
+ }
468
+ _run_test_shap (
469
+ clarify_processor ,
470
+ data_config ,
471
+ model_config ,
472
+ shap_config ,
473
+ model_scores ,
474
+ expected_predictor_config ,
475
+ )
0 commit comments