Skip to content

Commit 769d30c

Browse files
committed
test: add explainer config integ test
1 parent fda4827 commit 769d30c

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

tests/integ/test_explainer.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@
4040
SHAP_BASELINE = "1,2,3,4,5,6,7"
4141
XGBOOST_DATA_PATH = os.path.join(DATA_DIR, "xgboost_model")
4242

43+
CLARIFY_SHAP_BASELINE_CONFIG = ClarifyShapBaselineConfig(shap_baseline=SHAP_BASELINE)
44+
CLARIFY_SHAP_CONFIG = ClarifyShapConfig(shap_baseline_config=CLARIFY_SHAP_BASELINE_CONFIG)
45+
CLARIFY_EXPLAINER_CONFIG = ClarifyExplainerConfig(shap_config=CLARIFY_SHAP_CONFIG)
46+
EXPLAINER_CONFIG = ExplainerConfig(clarify_explainer_config=CLARIFY_EXPLAINER_CONFIG)
47+
4348

4449
@pytest.yield_fixture(scope="module")
4550
def endpoint_name(sagemaker_session):
@@ -66,19 +71,26 @@ def endpoint_name(sagemaker_session):
6671
role=ROLE,
6772
sagemaker_session=sagemaker_session,
6873
)
69-
clarify_shap_baseline_config = ClarifyShapBaselineConfig(shap_baseline=SHAP_BASELINE)
70-
clarify_shap_config = ClarifyShapConfig(shap_baseline_config=clarify_shap_baseline_config)
71-
clarify_explainer_config = ClarifyExplainerConfig(shap_config=clarify_shap_config)
72-
explainer_config = ExplainerConfig(clarify_explainer_config=clarify_explainer_config)
7374
xgb_model.deploy(
7475
INSTANCE_COUNT,
7576
INSTANCE_TYPE,
7677
endpoint_name=endpoint_name,
77-
explainer_config=explainer_config,
78+
explainer_config=EXPLAINER_CONFIG,
7879
)
7980
yield endpoint_name
8081

8182

83+
def test_describe_explainer_config(sagemaker_session, endpoint_name):
84+
endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(
85+
EndpointName=endpoint_name
86+
)
87+
88+
endpoint_config_desc = sagemaker_session.sagemaker_client.describe_endpoint_config(
89+
EndpointConfigName=endpoint_desc["EndpointConfigName"]
90+
)
91+
assert endpoint_config_desc["ExplainerConfig"] == EXPLAINER_CONFIG._to_request_dict()
92+
93+
8294
def test_invoke_explainer_enabled_endpoint(sagemaker_session, endpoint_name):
8395
response = sagemaker_session.sagemaker_runtime_client.invoke_endpoint(
8496
EndpointName=endpoint_name,

0 commit comments

Comments
 (0)