
Commit e7ad3a9

fix: address feedbacks
1 parent e9bce93 commit e7ad3a9

12 files changed: +184 −103 lines changed


src/sagemaker/estimator.py

Lines changed: 4 additions & 0 deletions
@@ -1378,6 +1378,7 @@ def deploy(
         model_data_download_timeout=None,
         container_startup_health_check_timeout=None,
         inference_recommendation_id=None,
+        explainer_config=None,
         **kwargs,
     ):
         """Deploy the trained model to an Amazon SageMaker endpoint.
@@ -1458,6 +1459,8 @@ def deploy(
             inference_recommendation_id (str): The recommendation id which specifies the
                 recommendation you picked from inference recommendation job results and
                 would like to deploy the model and endpoint with recommended parameters.
+            explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability
+                configuration for use with Amazon SageMaker Clarify. (default: None)
             **kwargs: Passed to invocation of ``create_model()``.
                 Implementations may customize ``create_model()`` to accept
                 ``**kwargs`` to customize model creation during deploy.
@@ -1516,6 +1519,7 @@ def deploy(
             data_capture_config=data_capture_config,
             serverless_inference_config=serverless_inference_config,
             async_inference_config=async_inference_config,
+            explainer_config=explainer_config,
             volume_size=volume_size,
             model_data_download_timeout=model_data_download_timeout,
             container_startup_health_check_timeout=container_startup_health_check_timeout,
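
In practice, the new argument lets a Clarify explainer be attached when an estimator is deployed. A minimal sketch, assuming an already-fitted estimator and an illustrative one-record CSV baseline; the Clarify helper classes are the ones exercised in tests/integ/test_explainer.py and are assumed to be importable from sagemaker.explainer.clarify_explainer_config:

from sagemaker.explainer import ExplainerConfig
from sagemaker.explainer.clarify_explainer_config import (
    ClarifyExplainerConfig,
    ClarifyShapBaselineConfig,
    ClarifyShapConfig,
)

# Illustrative SHAP baseline: one CSV record matching the model's input schema.
shap_config = ClarifyShapConfig(
    shap_baseline_config=ClarifyShapBaselineConfig(shap_baseline="1,2,3,4")
)
explainer_config = ExplainerConfig(
    clarify_explainer_config=ClarifyExplainerConfig(
        shap_config=shap_config,
        enable_explanations="`true`",  # JMESPath expression; explain every record by default
    )
)

# `estimator` is assumed to be a fitted sagemaker.estimator.Estimator.
predictor = estimator.deploy(
    initial_instance_count=1,
    instance_type="ml.c5.xlarge",
    explainer_config=explainer_config,
)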

src/sagemaker/explainer/clarify_explainer_config.py

Lines changed: 80 additions & 76 deletions
Large diffs are not rendered by default.

src/sagemaker/explainer/explainer_config.py

Lines changed: 5 additions & 7 deletions
@@ -10,27 +10,25 @@
 # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
 # ANY KIND, either express or implied. See the License for the specific
 # language governing permissions and limitations under the License.
-"""A class for ExplainerConfig
-
-Use ExplainerConfig to activate explainers.
-"""
+"""A member of ``CreateEndpointConfig`` that enables explainers."""

 from __future__ import print_function, absolute_import
+from typing import Optional
 from sagemaker.explainer.clarify_explainer_config import ClarifyExplainerConfig


 class ExplainerConfig(object):
-    """Config object to activate explainers."""
+    """A parameter to activate explainers."""

     def __init__(
         self,
-        clarify_explainer_config: ClarifyExplainerConfig = None,
+        clarify_explainer_config: Optional[ClarifyExplainerConfig] = None,
     ):
         """Initializes a config object to activate explainer.

         Args:
             clarify_explainer_config (:class:`~sagemaker.explainer.explainer_config.ClarifyExplainerConfig`):
-                A config contains parameters for the SageMaker Clarify explainer. (Default: None)
+                Optional. A config contains parameters for the SageMaker Clarify explainer. (Default: None)
         """  # noqa E501  # pylint: disable=line-too-long
         self.clarify_explainer_config = clarify_explainer_config
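
The new module docstring's "member of ``CreateEndpointConfig``" refers to the ExplainerConfig block of the CreateEndpointConfig API request. The unit test added in this commit pins down only the outer ``ClarifyExplainerConfig`` key, so the nested fields below are an illustrative sketch of the Clarify shape rather than output captured from the SDK:

# Illustrative shape of the ExplainerConfig member sent with CreateEndpointConfig.
# Only the outer "ClarifyExplainerConfig" key is asserted by the unit test in this commit;
# the nested field names follow the SageMaker API, and the values are made up.
explainer_config_dict = {
    "ClarifyExplainerConfig": {
        "EnableExplanations": "`true`",  # JMESPath expression evaluated per record
        "ShapConfig": {
            "ShapBaselineConfig": {"ShapBaseline": "1,2,3,4"},
        },
    },
}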

src/sagemaker/huggingface/model.py

Lines changed: 4 additions & 0 deletions
@@ -210,6 +210,7 @@ def deploy(
         model_data_download_timeout=None,
         container_startup_health_check_timeout=None,
         inference_recommendation_id=None,
+        explainer_config=None,
         **kwargs,
     ):
         """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
@@ -286,6 +287,8 @@ def deploy(
             inference_recommendation_id (str): The recommendation id which specifies the
                 recommendation you picked from inference recommendation job results and
                 would like to deploy the model and endpoint with recommended parameters.
+            explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability
+                configuration for use with Amazon SageMaker Clarify. (default: None)
         Raises:
             ValueError: If arguments combination check failed in these circumstances:
                 - If no role is specified or
@@ -322,6 +325,7 @@ def deploy(
             model_data_download_timeout=model_data_download_timeout,
             container_startup_health_check_timeout=container_startup_health_check_timeout,
             inference_recommendation_id=inference_recommendation_id,
+            explainer_config=explainer_config,
         )

     def register(

src/sagemaker/inference_recommender/inference_recommender_mixin.py

Lines changed: 16 additions & 0 deletions
@@ -215,6 +215,7 @@ def _update_params(
         accelerator_type = kwargs["accelerator_type"]
         async_inference_config = kwargs["async_inference_config"]
         serverless_inference_config = kwargs["serverless_inference_config"]
+        explainer_config = kwargs["explainer_config"]
         inference_recommendation_id = kwargs["inference_recommendation_id"]
         inference_recommender_job_results = kwargs["inference_recommender_job_results"]
         if inference_recommendation_id is not None:
@@ -225,6 +226,7 @@ def _update_params(
                 async_inference_config=async_inference_config,
                 serverless_inference_config=serverless_inference_config,
                 inference_recommendation_id=inference_recommendation_id,
+                explainer_config=explainer_config,
             )
         elif inference_recommender_job_results is not None:
             inference_recommendation = self._update_params_for_right_size(
@@ -233,6 +235,7 @@ def _update_params(
                 accelerator_type,
                 serverless_inference_config,
                 async_inference_config,
+                explainer_config,
             )
         return inference_recommendation or (instance_type, initial_instance_count)

@@ -243,6 +246,7 @@ def _update_params_for_right_size(
         accelerator_type=None,
         serverless_inference_config=None,
         async_inference_config=None,
+        explainer_config=None,
     ):
         """Validates that Inference Recommendation parameters can be used in `model.deploy()`

@@ -262,6 +266,8 @@ def _update_params_for_right_size(
                 whether serverless_inference_config has been passed into `model.deploy()`.
             async_inference_config (sagemaker.model_monitor.AsyncInferenceConfig):
                 whether async_inference_config has been passed into `model.deploy()`.
+            explainer_config (sagemaker.explainer.ExplainerConfig): whether explainer_config
+                has been passed into `model.deploy()`.

         Returns:
             (string, int) or None: Top instance_type and associated initial_instance_count
@@ -285,6 +291,11 @@
                 "serverless_inference_config is specified. Overriding right_size() recommendations."
             )
             return None
+        if explainer_config:
+            LOGGER.warning(
+                "explainer_config is specified. Overriding right_size() recommendations."
+            )
+            return None

         instance_type = self.inference_recommendations[0]["EndpointConfiguration"]["InstanceType"]
         initial_instance_count = self.inference_recommendations[0]["EndpointConfiguration"][
@@ -300,6 +311,7 @@ def _update_params_for_recommendation_id(
         async_inference_config,
         serverless_inference_config,
         inference_recommendation_id,
+        explainer_config,
     ):
         """Update parameters with inference recommendation results.

@@ -332,6 +344,8 @@ def _update_params_for_recommendation_id(
                 the recommendation you picked from inference recommendation job
                 results and would like to deploy the model and endpoint with
                 recommended parameters.
+            explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability
+                configuration for use with Amazon SageMaker Clarify. Default: None.
         Raises:
             ValueError: If arguments combination check failed in these circumstances:
                 - If only one of instance type or instance count specified or
@@ -367,6 +381,8 @@
             raise ValueError(
                 "serverless_inference_config is not compatible with inference_recommendation_id."
             )
+        if explainer_config is not None:
+            raise ValueError("explainer_config is not compatible with inference_recommendation_id.")

         # Validate recommendation id
         if not re.match(r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,63}\/\w{8}$", inference_recommendation_id):
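
Taken together, these guards mean the explainer takes precedence over Inference Recommender output: right_size() results are ignored with a warning, and a recommendation id cannot be combined with an explainer at all. A sketch under the assumption that `model` has already run right_size() and that `explainer_config` is built as in the sketch under src/sagemaker/estimator.py above (the recommendation id is hypothetical):

# right_size() recommendations are overridden with a warning
# ("explainer_config is specified. Overriding right_size() recommendations."),
# so the explicitly passed instance settings are used instead.
predictor = model.deploy(
    instance_type="ml.c5.xlarge",
    initial_instance_count=1,
    explainer_config=explainer_config,
)

# Combining an explainer with a recommendation id is rejected:
# raises ValueError("explainer_config is not compatible with inference_recommendation_id.")
model.deploy(
    inference_recommendation_id="my-recommendation-job/a1b2c3d4",  # hypothetical id
    explainer_config=explainer_config,
)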

src/sagemaker/model.py

Lines changed: 3 additions & 3 deletions
@@ -1077,11 +1077,11 @@ def deploy(
         data_capture_config=None,
         async_inference_config=None,
         serverless_inference_config=None,
-        explainer_config=None,
         volume_size=None,
         model_data_download_timeout=None,
         container_startup_health_check_timeout=None,
         inference_recommendation_id=None,
+        explainer_config=None,
         **kwargs,
     ):
         """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
@@ -1147,8 +1147,6 @@
                 empty object passed through, will use pre-defined values in
                 ``ServerlessInferenceConfig`` class to deploy serverless endpoint. Deploy an
                 instance based endpoint if it's None. (default: None)
-            explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability
-                configuration for use with Amazon SageMaker Clarify. Default: None.
             volume_size (int): The size, in GB, of the ML storage volume attached to individual
                 inference instance associated with the production variant. Currenly only Amazon EBS
                 gp2 storage volumes are supported.
@@ -1162,6 +1160,8 @@
             inference_recommendation_id (str): The recommendation id which specifies the
                 recommendation you picked from inference recommendation job results and
                 would like to deploy the model and endpoint with recommended parameters.
+            explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability
+                configuration for use with Amazon SageMaker Clarify. Default: None.
         Raises:
             ValueError: If arguments combination check failed in these circumstances:
                 - If no role is specified or

src/sagemaker/predictor.py

Lines changed: 1 addition & 8 deletions
@@ -127,7 +127,6 @@ def predict(
         target_model=None,
         target_variant=None,
         inference_id=None,
-        enable_explanations=None,
     ):
         """Return the inference from the specified endpoint.

@@ -148,8 +147,6 @@
                 model you want to host and the resources you want to deploy for hosting it.
             inference_id (str): If you provide a value, it is added to the captured data
                 when you enable data capture on the endpoint (Default: None).
-            enable_explanations (str): An optional JMESPath expression used to override the
-                EnableExplanations parameter of the ClarifyExplainerConfig. (Default: None).

         Returns:
             object: Inference for the given input. If a deserializer was specified when creating
@@ -159,7 +156,7 @@
         """

         request_args = self._create_request_args(
-            data, initial_args, target_model, target_variant, inference_id, enable_explanations
+            data, initial_args, target_model, target_variant, inference_id
         )
         response = self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint(**request_args)
         return self._handle_response(response)
@@ -177,7 +174,6 @@ def _create_request_args(
         target_model=None,
         target_variant=None,
         inference_id=None,
-        enable_explanations=None,
     ):
         """Placeholder docstring"""
         args = dict(initial_args) if initial_args else {}
@@ -200,9 +196,6 @@
         if inference_id:
             args["InferenceId"] = inference_id

-        if enable_explanations:
-            args["EnableExplanations"] = enable_explanations
-
         data = self.serializer.serialize(data)

         args["Body"] = data

src/sagemaker/session.py

Lines changed: 8 additions & 8 deletions
@@ -3660,10 +3660,10 @@ def create_endpoint_config(
         tags=None,
         kms_key=None,
         data_capture_config_dict=None,
-        explainer_config_dict=None,
         volume_size=None,
         model_data_download_timeout=None,
         container_startup_health_check_timeout=None,
+        explainer_config_dict=None,
     ):
         """Create an Amazon SageMaker endpoint configuration.

@@ -3687,8 +3687,6 @@
                 attached to the instance hosting the endpoint.
             data_capture_config_dict (dict): Specifies configuration related to Endpoint data
                 capture for use with Amazon SageMaker Model Monitoring. Default: None.
-            explainer_config_dict (dict): Specifies configuration to enable explainers.
-                Default: None.
             volume_size (int): The size, in GB, of the ML storage volume attached to individual
                 inference instance associated with the production variant. Currenly only Amazon EBS
                 gp2 storage volumes are supported.
@@ -3699,6 +3697,8 @@
                 inference container to pass health check by SageMaker Hosting. For more information
                 about health check see:
                 https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
+            explainer_config_dict (dict): Specifies configuration to enable explainers.
+                Default: None.

         Example:
             >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
@@ -3767,8 +3767,8 @@ def create_endpoint_config_from_existing(
         new_tags=None,
         new_kms_key=None,
         new_data_capture_config_dict=None,
-        new_explainer_config_dict=None,
         new_production_variants=None,
+        new_explainer_config_dict=None,
     ):
         """Create an Amazon SageMaker endpoint configuration from an existing one.

@@ -3793,12 +3793,12 @@
                 capture for use with Amazon SageMaker Model Monitoring (default: None).
                 If not specified, the data capture configuration of the existing
                 endpoint configuration is used.
-            new_explainer_config_dict (dict): Specifies configuration to enable explainers.
-                (default: None). If not specified, the explainer configuration of the existing
-                endpoint configuration is used.
             new_production_variants (list[dict]): The configuration for which model(s) to host and
                 the resources to deploy for hosting the model(s). If not specified,
                 the ``ProductionVariants`` of the existing endpoint configuration is used.
+            new_explainer_config_dict (dict): Specifies configuration to enable explainers.
+                (default: None). If not specified, the explainer configuration of the existing
+                endpoint configuration is used.

         Returns:
             str: Name of the endpoint point configuration created.
@@ -4407,7 +4407,7 @@ def endpoint_from_production_variants(
             async_inference_config_dict (dict) : specifies configuration related to async endpoint.
                 Use this configuration when trying to create async endpoint and make async inference
                 (default: None)
-            explainer_config_dict (dict) : specifies configuration related to explainer.
+            explainer_config_dict (dict) : Specifies configuration related to explainer.
                 Use this configuration when trying to use online explainability.
                 (default: None)
         Returns:
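
At the Session level the explainer is carried as a plain dict with the same ``ClarifyExplainerConfig`` shape sketched under explainer_config.py above. A sketch of wiring it through ``endpoint_from_production_variants``, assuming a SageMaker model named ``my-model`` already exists, that ``production_variant`` is importable from ``sagemaker.session``, and illustrative field values:

from sagemaker.session import production_variant

variant = production_variant(
    model_name="my-model",        # assumed pre-created SageMaker model
    instance_type="ml.c5.xlarge",
    initial_instance_count=1,
)

sagemaker_session.endpoint_from_production_variants(
    name="my-explainer-endpoint",
    production_variants=[variant],
    explainer_config_dict={
        "ClarifyExplainerConfig": {
            "EnableExplanations": "`true`",
            "ShapConfig": {"ShapBaselineConfig": {"ShapBaseline": "1,2,3,4"}},
        }
    },
)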

src/sagemaker/tensorflow/model.py

Lines changed: 2 additions & 0 deletions
@@ -324,6 +324,7 @@ def deploy(
         model_data_download_timeout=None,
         container_startup_health_check_timeout=None,
         inference_recommendation_id=None,
+        explainer_config=None,
     ):
         """Deploy a Tensorflow ``Model`` to a SageMaker ``Endpoint``."""

@@ -349,6 +350,7 @@ def deploy(
             container_startup_health_check_timeout=container_startup_health_check_timeout,
             update_endpoint=update_endpoint,
             inference_recommendation_id=inference_recommendation_id,
+            explainer_config=explainer_config,
         )

     def _eia_supported(self):

tests/integ/test_explainer.py

Lines changed: 23 additions & 1 deletion
@@ -42,7 +42,9 @@

 CLARIFY_SHAP_BASELINE_CONFIG = ClarifyShapBaselineConfig(shap_baseline=SHAP_BASELINE)
 CLARIFY_SHAP_CONFIG = ClarifyShapConfig(shap_baseline_config=CLARIFY_SHAP_BASELINE_CONFIG)
-CLARIFY_EXPLAINER_CONFIG = ClarifyExplainerConfig(shap_config=CLARIFY_SHAP_CONFIG)
+CLARIFY_EXPLAINER_CONFIG = ClarifyExplainerConfig(
+    shap_config=CLARIFY_SHAP_CONFIG, enable_explanations="`true`"
+)
 EXPLAINER_CONFIG = ExplainerConfig(clarify_explainer_config=CLARIFY_EXPLAINER_CONFIG)


@@ -107,3 +109,23 @@ def test_invoke_explainer_enabled_endpoint(sagemaker_session, endpoint_name):
         assert response_body_json.get("predictions")
     finally:
         response_body_stream.close()
+
+
+def test_invoke_endpoint_with_on_demand_explanations(sagemaker_session, endpoint_name):
+    response = sagemaker_session.sagemaker_runtime_client.invoke_endpoint(
+        EndpointName=endpoint_name,
+        EnableExplanations="`false`",
+        Body=TEST_CSV_DATA,
+        ContentType="text/csv",
+        Accept="text/csv",
+    )
+    assert response
+    response_body_stream = response["Body"]
+    try:
+        response_body_json = json.load(codecs.getreader("utf-8")(response_body_stream))
+        assert response_body_json
+        # no records are explained when EnableExplanations="`false`"
+        assert response_body_json.get("explanations") == {}
+        assert response_body_json.get("predictions")
+    finally:
+        response_body_stream.close()

tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py

Lines changed: 30 additions & 0 deletions
@@ -10,6 +10,7 @@
 )
 from sagemaker.async_inference import AsyncInferenceConfig
 from sagemaker.serverless import ServerlessInferenceConfig
+from sagemaker.explainer import ExplainerConfig

 import pytest

@@ -667,5 +668,34 @@ def test_deploy_right_size_async_override(sagemaker_session, default_right_sized
     )


+@patch("sagemaker.utils.name_from_base", MagicMock(return_value=MODEL_NAME))
+def test_deploy_right_size_explainer_config_override(sagemaker_session, default_right_sized_model):
+    default_right_sized_model.name = MODEL_NAME
+    mock_clarify_explainer_config = MagicMock()
+    mock_clarify_explainer_config_dict = {
+        "EnableExplanations": "`false`",
+    }
+    mock_clarify_explainer_config._to_request_dict.return_value = mock_clarify_explainer_config_dict
+    explainer_config = ExplainerConfig(clarify_explainer_config=mock_clarify_explainer_config)
+    explainer_config_dict = {"ClarifyExplainerConfig": mock_clarify_explainer_config_dict}
+
+    default_right_sized_model.deploy(
+        instance_type="ml.c5.2xlarge",
+        initial_instance_count=1,
+        explainer_config=explainer_config,
+    )
+
+    sagemaker_session.endpoint_from_production_variants.assert_called_with(
+        name=MODEL_NAME,
+        production_variants=[ANY],
+        tags=None,
+        kms_key=None,
+        wait=True,
+        data_capture_config_dict=None,
+        async_inference_config_dict=None,
+        explainer_config_dict=explainer_config_dict,
+    )
+
+
 # TODO -> cover inference_recommendation_id cases
 # ...
