@@ -442,21 +442,22 @@ def test_post_training_bias(
442
442
)
443
443
444
444
445
- @patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
446
- def test_shap (
445
+ def _run_test_shap (
447
446
name_from_base ,
448
447
clarify_processor ,
449
448
clarify_processor_with_job_name_prefix ,
450
449
data_config ,
451
450
model_config ,
452
451
shap_config ,
452
+ model_scores ,
453
+ expected_predictor_config ,
453
454
):
454
455
with patch .object (SageMakerClarifyProcessor , "_run" , return_value = None ) as mock_method :
455
456
clarify_processor .run_explainability (
456
457
data_config ,
457
458
model_config ,
458
459
shap_config ,
459
- model_scores = None ,
460
+ model_scores = model_scores ,
460
461
wait = True ,
461
462
job_name = "test" ,
462
463
experiment_config = {"ExperimentName" : "AnExperiment" },
@@ -485,11 +486,7 @@ def test_shap(
485
486
"save_local_shap_values" : True ,
486
487
}
487
488
},
488
- "predictor" : {
489
- "model_name" : "xgboost-model" ,
490
- "instance_type" : "ml.c5.xlarge" ,
491
- "initial_instance_count" : 1 ,
492
- },
489
+ "predictor" : expected_predictor_config ,
493
490
}
494
491
mock_method .assert_called_with (
495
492
data_config ,
@@ -504,7 +501,7 @@ def test_shap(
504
501
data_config ,
505
502
model_config ,
506
503
shap_config ,
507
- model_scores = None ,
504
+ model_scores = model_scores ,
508
505
wait = True ,
509
506
experiment_config = {"ExperimentName" : "AnExperiment" },
510
507
)
@@ -518,3 +515,63 @@ def test_shap(
518
515
None ,
519
516
{"ExperimentName" : "AnExperiment" },
520
517
)
518
+
519
+
520
+ @patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
521
+ def test_shap (
522
+ name_from_base ,
523
+ clarify_processor ,
524
+ clarify_processor_with_job_name_prefix ,
525
+ data_config ,
526
+ model_config ,
527
+ shap_config ,
528
+ ):
529
+ expected_predictor_config = {
530
+ "model_name" : "xgboost-model" ,
531
+ "instance_type" : "ml.c5.xlarge" ,
532
+ "initial_instance_count" : 1 ,
533
+ }
534
+ _run_test_shap (
535
+ name_from_base ,
536
+ clarify_processor ,
537
+ clarify_processor_with_job_name_prefix ,
538
+ data_config ,
539
+ model_config ,
540
+ shap_config ,
541
+ None ,
542
+ expected_predictor_config ,
543
+ )
544
+
545
+
546
+ @patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
547
+ def test_shap_with_predicted_label (
548
+ name_from_base ,
549
+ clarify_processor ,
550
+ clarify_processor_with_job_name_prefix ,
551
+ data_config ,
552
+ model_config ,
553
+ shap_config ,
554
+ ):
555
+ probability = "pr"
556
+ label_headers = ["success" ]
557
+ model_scores = ModelPredictedLabelConfig (
558
+ probability = probability ,
559
+ label_headers = label_headers ,
560
+ )
561
+ expected_predictor_config = {
562
+ "model_name" : "xgboost-model" ,
563
+ "instance_type" : "ml.c5.xlarge" ,
564
+ "initial_instance_count" : 1 ,
565
+ "probability" : probability ,
566
+ "label_headers" : label_headers ,
567
+ }
568
+ _run_test_shap (
569
+ name_from_base ,
570
+ clarify_processor ,
571
+ clarify_processor_with_job_name_prefix ,
572
+ data_config ,
573
+ model_config ,
574
+ shap_config ,
575
+ model_scores ,
576
+ expected_predictor_config ,
577
+ )
0 commit comments