Skip to content

Commit ab0a3fe

Browse files
authored
Merge branch 'dev' into support-for-pytorch-1-10-0
2 parents ca58ed9 + 8e9d9b7 commit ab0a3fe

15 files changed

+193
-21
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,5 @@ venv/
2727
*.swp
2828
.docker/
2929
env/
30-
.vscode/
30+
.vscode/
31+
.python-version

src/sagemaker/clarify.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -290,11 +290,15 @@ def __init__(
290290
probability_threshold (float): An optional value for binary prediction tasks in which
291291
the model returns a probability, to indicate the threshold to convert the
292292
prediction to a boolean value. Default is 0.5.
293-
label_headers (list): List of label values - one for each score of the ``probability``.
293+
label_headers (list[str]): List of headers, each for a predicted score in model output.
294+
For bias analysis, it is used to extract the label value with the highest score as
295+
predicted label. For explainability jobs, it is used to beautify the analysis report
296+
by replacing placeholders like "label0".
294297
"""
295298
self.label = label
296299
self.probability = probability
297300
self.probability_threshold = probability_threshold
301+
self.label_headers = label_headers
298302
if probability_threshold is not None:
299303
try:
300304
float(probability_threshold)
@@ -1060,10 +1064,10 @@ def run_explainability(
10601064
explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list):
10611065
Config of the specific explainability method or a list of ExplainabilityConfig
10621066
objects. Currently, SHAP and PDP are the two methods supported.
1063-
model_scores(str|int|ModelPredictedLabelConfig): Index or JSONPath location in the
1064-
model output for the predicted scores to be explained. This is not required if the
1065-
model output is a single score. Alternatively, an instance of
1066-
ModelPredictedLabelConfig can be provided.
1067+
model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
1068+
Index or JSONPath to locate the predicted scores in the model output. This is not
1069+
required if the model output is a single score. Alternatively, it can be an instance
1070+
of ModelPredictedLabelConfig to provide more parameters like label_headers.
10671071
wait (bool): Whether the call should wait until the job completes (default: True).
10681072
logs (bool): Whether to show the logs produced by the job.
10691073
Only meaningful when ``wait`` is True (default: True).

src/sagemaker/estimator.py

+1
Original file line numberDiff line numberDiff line change
@@ -2343,6 +2343,7 @@ def _stage_user_code_in_s3(self):
23432343
dependencies=self.dependencies,
23442344
kms_key=kms_key,
23452345
s3_resource=self.sagemaker_session.s3_resource,
2346+
settings=self.sagemaker_session.settings,
23462347
)
23472348

23482349
def _model_source_dir(self):

src/sagemaker/fw_utils.py

+11
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919
import shutil
2020
import tempfile
2121
from collections import namedtuple
22+
from typing import Optional
2223

2324
import sagemaker.image_uris
25+
from sagemaker.session_settings import SessionSettings
2426
import sagemaker.utils
2527

2628
from sagemaker.deprecations import renamed_warning
@@ -216,6 +218,7 @@ def tar_and_upload_dir(
216218
dependencies=None,
217219
kms_key=None,
218220
s3_resource=None,
221+
settings: Optional[SessionSettings] = None,
219222
):
220223
"""Package source files and upload a compress tar file to S3.
221224
@@ -243,6 +246,9 @@ def tar_and_upload_dir(
243246
s3_resource (boto3.resource("s3")): Optional. Pre-instantiated Boto3 Resource
244247
for S3 connections, can be used to customize the configuration,
245248
e.g. set the endpoint URL (default: None).
249+
settings (sagemaker.session_settings.SessionSettings): Optional. The settings
250+
of the SageMaker ``Session``, can be used to override the default encryption
251+
behavior (default: None).
246252
Returns:
247253
sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and
248254
script name.
@@ -254,6 +260,7 @@ def tar_and_upload_dir(
254260
dependencies = dependencies or []
255261
key = "%s/sourcedir.tar.gz" % s3_key_prefix
256262
tmp = tempfile.mkdtemp()
263+
encrypt_artifact = True if settings is None else settings.encrypt_repacked_artifacts
257264

258265
try:
259266
source_files = _list_files_to_compress(script, directory) + dependencies
@@ -263,6 +270,10 @@ def tar_and_upload_dir(
263270

264271
if kms_key:
265272
extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": kms_key}
273+
elif encrypt_artifact:
274+
# encrypt the tarball at rest in S3 with the default AWS managed KMS key for S3
275+
# see https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutObject.html#API_PutObject_RequestSyntax
276+
extra_args = {"ServerSideEncryption": "aws:kms"}
266277
else:
267278
extra_args = None
268279

src/sagemaker/model.py

+1
Original file line numberDiff line numberDiff line change
@@ -1131,6 +1131,7 @@ def _upload_code(self, key_prefix, repack=False):
11311131
script=self.entry_point,
11321132
directory=self.source_dir,
11331133
dependencies=self.dependencies,
1134+
settings=self.sagemaker_session.settings,
11341135
)
11351136

11361137
if repack and self.model_data is not None and self.entry_point is not None:

src/sagemaker/model_monitor/clarify_model_monitoring.py

+24-7
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from sagemaker import image_uris, s3
2727
from sagemaker.session import Session
2828
from sagemaker.utils import name_from_base
29-
from sagemaker.clarify import SageMakerClarifyProcessor
29+
from sagemaker.clarify import SageMakerClarifyProcessor, ModelPredictedLabelConfig
3030

3131
_LOGGER = logging.getLogger(__name__)
3232

@@ -833,9 +833,10 @@ def suggest_baseline(
833833
specific explainability method. Currently, only SHAP is supported.
834834
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
835835
endpoint to be created.
836-
model_scores (int or str): Index or JSONPath location in the model output for the
837-
predicted scores to be explained. This is not required if the model output is
838-
a single score.
836+
model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
837+
Index or JSONPath to locate the predicted scores in the model output. This is not
838+
required if the model output is a single score. Alternatively, it can be an instance
839+
of ModelPredictedLabelConfig to provide more parameters like label_headers.
839840
wait (bool): Whether the call should wait until the job completes (default: False).
840841
logs (bool): Whether to show the logs produced by the job.
841842
Only meaningful when wait is True (default: False).
@@ -865,14 +866,24 @@ def suggest_baseline(
865866
headers = copy.deepcopy(data_config.headers)
866867
if headers and data_config.label in headers:
867868
headers.remove(data_config.label)
869+
if model_scores is None:
870+
inference_attribute = None
871+
label_headers = None
872+
elif isinstance(model_scores, ModelPredictedLabelConfig):
873+
inference_attribute = str(model_scores.label)
874+
label_headers = model_scores.label_headers
875+
else:
876+
inference_attribute = str(model_scores)
877+
label_headers = None
868878
self.latest_baselining_job_config = ClarifyBaseliningConfig(
869879
analysis_config=ExplainabilityAnalysisConfig(
870880
explainability_config=explainability_config,
871881
model_config=model_config,
872882
headers=headers,
883+
label_headers=label_headers,
873884
),
874885
features_attribute=data_config.features,
875-
inference_attribute=model_scores if model_scores is None else str(model_scores),
886+
inference_attribute=inference_attribute,
876887
)
877888
self.latest_baselining_job_name = baselining_job_name
878889
self.latest_baselining_job = ClarifyBaseliningJob(
@@ -1166,7 +1177,7 @@ def attach(cls, monitor_schedule_name, sagemaker_session=None):
11661177
class ExplainabilityAnalysisConfig:
11671178
"""Analysis configuration for ModelExplainabilityMonitor."""
11681179

1169-
def __init__(self, explainability_config, model_config, headers=None):
1180+
def __init__(self, explainability_config, model_config, headers=None, label_headers=None):
11701181
"""Creates an analysis config dictionary.
11711182
11721183
Args:
@@ -1175,13 +1186,19 @@ def __init__(self, explainability_config, model_config, headers=None):
11751186
model_config (sagemaker.clarify.ModelConfig): Config object related to bias
11761187
configurations.
11771188
headers (list[str]): A list of feature names (without label) of model/endpoint input.
1189+
label_headers (list[str]): List of headers, each for a predicted score in model output.
1190+
It is used to beautify the analysis report by replacing placeholders like "label0".
1191+
11781192
"""
1193+
predictor_config = model_config.get_predictor_config()
11791194
self.analysis_config = {
11801195
"methods": explainability_config.get_explainability_config(),
1181-
"predictor": model_config.get_predictor_config(),
1196+
"predictor": predictor_config,
11821197
}
11831198
if headers is not None:
11841199
self.analysis_config["headers"] = headers
1200+
if label_headers is not None:
1201+
predictor_config["label_headers"] = label_headers
11851202

11861203
def _to_dict(self):
11871204
"""Generates a request dictionary using the parameters provided to the class."""

src/sagemaker/session.py

+5
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
sts_regional_endpoint,
4343
)
4444
from sagemaker import exceptions
45+
from sagemaker.session_settings import SessionSettings
4546

4647
LOGGER = logging.getLogger("sagemaker")
4748

@@ -85,6 +86,7 @@ def __init__(
8586
sagemaker_runtime_client=None,
8687
sagemaker_featurestore_runtime_client=None,
8788
default_bucket=None,
89+
settings=SessionSettings(),
8890
):
8991
"""Initialize a SageMaker ``Session``.
9092
@@ -110,13 +112,16 @@ def __init__(
110112
If not provided, a default bucket will be created based on the following format:
111113
"sagemaker-{region}-{aws-account-id}".
112114
Example: "sagemaker-my-custom-bucket".
115+
settings (sagemaker.session_settings.SessionSettings): Optional. Set of optional
116+
parameters to apply to the session.
113117
"""
114118
self._default_bucket = None
115119
self._default_bucket_name_override = default_bucket
116120
self.s3_resource = None
117121
self.s3_client = None
118122
self.config = None
119123
self.lambda_client = None
124+
self.settings = settings
120125

121126
self._initialize(
122127
boto_session=boto_session,

src/sagemaker/session_settings.py

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Defines classes to parametrize a SageMaker ``Session``."""
14+
15+
from __future__ import absolute_import
16+
17+
18+
class SessionSettings(object):
19+
"""Optional container class for settings to apply to a SageMaker session."""
20+
21+
def __init__(self, encrypt_repacked_artifacts=True) -> None:
22+
"""Initialize the ``SessionSettings`` of a SageMaker ``Session``.
23+
24+
Args:
25+
encrypt_repacked_artifacts (bool): Flag to indicate whether to encrypt the artifacts
26+
at rest in S3 using the default AWS managed KMS key for S3 when a custom KMS key
27+
is not provided (Default: True).
28+
"""
29+
self._encrypt_repacked_artifacts = encrypt_repacked_artifacts
30+
31+
@property
32+
def encrypt_repacked_artifacts(self) -> bool:
33+
"""Return True if repacked artifacts at rest in S3 should be encrypted by default."""
34+
return self._encrypt_repacked_artifacts

src/sagemaker/utils.py

+8
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from six.moves.urllib import parse
3030

3131
from sagemaker import deprecations
32+
from sagemaker.session_settings import SessionSettings
3233

3334

3435
ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$"
@@ -429,8 +430,15 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
429430
bucket, key = url.netloc, url.path.lstrip("/")
430431
new_key = key.replace(os.path.basename(key), os.path.basename(repacked_model_uri))
431432

433+
settings = (
434+
sagemaker_session.settings if sagemaker_session is not None else SessionSettings()
435+
)
436+
encrypt_artifact = settings.encrypt_repacked_artifacts
437+
432438
if kms_key:
433439
extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": kms_key}
440+
elif encrypt_artifact:
441+
extra_args = {"ServerSideEncryption": "aws:kms"}
434442
else:
435443
extra_args = None
436444
sagemaker_session.boto_session.resource(

tests/integ/test_clarify_model_monitor.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
HEADER_OF_LABEL = "Label"
5454
HEADERS_OF_FEATURES = ["F1", "F2", "F3", "F4", "F5", "F6", "F7"]
5555
ALL_HEADERS = [*HEADERS_OF_FEATURES, HEADER_OF_LABEL]
56+
HEADER_OF_PREDICTION = "Decision"
5657
DATASET_TYPE = "text/csv"
5758
CONTENT_TYPE = DATASET_TYPE
5859
ACCEPT_TYPE = DATASET_TYPE
@@ -325,7 +326,7 @@ def scheduled_explainability_monitor(
325326
):
326327
monitor_schedule_name = utils.unique_name_from_base("explainability-monitor")
327328
analysis_config = ExplainabilityAnalysisConfig(
328-
shap_config, model_config, headers=HEADERS_OF_FEATURES
329+
shap_config, model_config, headers=HEADERS_OF_FEATURES, label_headers=[HEADER_OF_PREDICTION]
329330
)
330331
s3_uri_monitoring_output = os.path.join(
331332
"s3://",

tests/unit/sagemaker/monitor/test_clarify_model_monitor.py

+33-5
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@
279279
# for bias
280280
ANALYSIS_CONFIG_LABEL = "Label"
281281
ANALYSIS_CONFIG_HEADERS_OF_FEATURES = ["F1", "F2", "F3"]
282+
ANALYSIS_CONFIG_LABEL_HEADERS = ["Decision"]
282283
ANALYSIS_CONFIG_ALL_HEADERS = [*ANALYSIS_CONFIG_HEADERS_OF_FEATURES, ANALYSIS_CONFIG_LABEL]
283284
ANALYSIS_CONFIG_LABEL_VALUES = [1]
284285
ANALYSIS_CONFIG_FACET_NAME = "F1"
@@ -330,6 +331,11 @@
330331
"content_type": CONTENT_TYPE,
331332
},
332333
}
334+
EXPLAINABILITY_ANALYSIS_CONFIG_WITH_LABEL_HEADERS = copy.deepcopy(EXPLAINABILITY_ANALYSIS_CONFIG)
335+
# noinspection PyTypeChecker
336+
EXPLAINABILITY_ANALYSIS_CONFIG_WITH_LABEL_HEADERS["predictor"][
337+
"label_headers"
338+
] = ANALYSIS_CONFIG_LABEL_HEADERS
333339

334340

335341
@pytest.fixture()
@@ -1048,25 +1054,44 @@ def test_explainability_analysis_config(shap_config, model_config):
10481054
explainability_config=shap_config,
10491055
model_config=model_config,
10501056
headers=ANALYSIS_CONFIG_HEADERS_OF_FEATURES,
1057+
label_headers=ANALYSIS_CONFIG_LABEL_HEADERS,
10511058
)
1052-
assert EXPLAINABILITY_ANALYSIS_CONFIG == config._to_dict()
1059+
assert EXPLAINABILITY_ANALYSIS_CONFIG_WITH_LABEL_HEADERS == config._to_dict()
10531060

10541061

1062+
@pytest.mark.parametrize(
1063+
"model_scores,explainability_analysis_config",
1064+
[
1065+
(INFERENCE_ATTRIBUTE, EXPLAINABILITY_ANALYSIS_CONFIG),
1066+
(
1067+
ModelPredictedLabelConfig(
1068+
label=INFERENCE_ATTRIBUTE, label_headers=ANALYSIS_CONFIG_LABEL_HEADERS
1069+
),
1070+
EXPLAINABILITY_ANALYSIS_CONFIG_WITH_LABEL_HEADERS,
1071+
),
1072+
],
1073+
)
10551074
def test_model_explainability_monitor_suggest_baseline(
1056-
model_explainability_monitor, sagemaker_session, data_config, shap_config, model_config
1075+
model_explainability_monitor,
1076+
sagemaker_session,
1077+
data_config,
1078+
shap_config,
1079+
model_config,
1080+
model_scores,
1081+
explainability_analysis_config,
10571082
):
10581083
clarify_model_monitor = model_explainability_monitor
10591084
# suggest baseline
10601085
clarify_model_monitor.suggest_baseline(
10611086
data_config=data_config,
10621087
explainability_config=shap_config,
10631088
model_config=model_config,
1064-
model_scores=INFERENCE_ATTRIBUTE,
1089+
model_scores=model_scores,
10651090
job_name=BASELINING_JOB_NAME,
10661091
)
10671092
assert isinstance(clarify_model_monitor.latest_baselining_job, ClarifyBaseliningJob)
10681093
assert (
1069-
EXPLAINABILITY_ANALYSIS_CONFIG
1094+
explainability_analysis_config
10701095
== clarify_model_monitor.latest_baselining_job_config.analysis_config._to_dict()
10711096
)
10721097
clarify_baselining_job = clarify_model_monitor.latest_baselining_job
@@ -1081,6 +1106,7 @@ def test_model_explainability_monitor_suggest_baseline(
10811106
analysis_config=None, # will pick up config from baselining job
10821107
baseline_job_name=BASELINING_JOB_NAME,
10831108
endpoint_input=ENDPOINT_NAME,
1109+
explainability_analysis_config=explainability_analysis_config,
10841110
# will pick up attributes from baselining job
10851111
)
10861112

@@ -1133,6 +1159,7 @@ def test_model_explainability_monitor_created_with_config(
11331159
sagemaker_session=sagemaker_session,
11341160
analysis_config=analysis_config,
11351161
constraints=CONSTRAINTS,
1162+
explainability_analysis_config=EXPLAINABILITY_ANALYSIS_CONFIG,
11361163
)
11371164

11381165
# update schedule
@@ -1263,6 +1290,7 @@ def _test_model_explainability_monitor_create_schedule(
12631290
features_attribute=FEATURES_ATTRIBUTE,
12641291
inference_attribute=str(INFERENCE_ATTRIBUTE),
12651292
),
1293+
explainability_analysis_config=None,
12661294
):
12671295
# create schedule
12681296
with patch(
@@ -1278,7 +1306,7 @@ def _test_model_explainability_monitor_create_schedule(
12781306
)
12791307
if not isinstance(analysis_config, str):
12801308
upload.assert_called_once()
1281-
assert json.loads(upload.call_args[0][0]) == EXPLAINABILITY_ANALYSIS_CONFIG
1309+
assert json.loads(upload.call_args[0][0]) == explainability_analysis_config
12821310

12831311
# validation
12841312
expected_arguments = {

0 commit comments

Comments
 (0)