Skip to content

Commit f007d03

Browse files
authored
Merge branch 'dev' into feature/large_pipeline
2 parents f69a544 + 8e9d9b7 commit f007d03

11 files changed

+125
-3
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,5 @@ venv/
2727
*.swp
2828
.docker/
2929
env/
30-
.vscode/
30+
.vscode/
31+
.python-version

src/sagemaker/estimator.py

+1
Original file line numberDiff line numberDiff line change
@@ -2343,6 +2343,7 @@ def _stage_user_code_in_s3(self):
23432343
dependencies=self.dependencies,
23442344
kms_key=kms_key,
23452345
s3_resource=self.sagemaker_session.s3_resource,
2346+
settings=self.sagemaker_session.settings,
23462347
)
23472348

23482349
def _model_source_dir(self):

src/sagemaker/fw_utils.py

+11
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919
import shutil
2020
import tempfile
2121
from collections import namedtuple
22+
from typing import Optional
2223

2324
import sagemaker.image_uris
25+
from sagemaker.session_settings import SessionSettings
2426
import sagemaker.utils
2527

2628
from sagemaker.deprecations import renamed_warning
@@ -203,6 +205,7 @@ def tar_and_upload_dir(
203205
dependencies=None,
204206
kms_key=None,
205207
s3_resource=None,
208+
settings: Optional[SessionSettings] = None,
206209
):
207210
"""Package source files and upload a compress tar file to S3.
208211
@@ -230,6 +233,9 @@ def tar_and_upload_dir(
230233
s3_resource (boto3.resource("s3")): Optional. Pre-instantiated Boto3 Resource
231234
for S3 connections, can be used to customize the configuration,
232235
e.g. set the endpoint URL (default: None).
236+
settings (sagemaker.session_settings.SessionSettings): Optional. The settings
237+
of the SageMaker ``Session``, can be used to override the default encryption
238+
behavior (default: None).
233239
Returns:
234240
sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and
235241
script name.
@@ -241,6 +247,7 @@ def tar_and_upload_dir(
241247
dependencies = dependencies or []
242248
key = "%s/sourcedir.tar.gz" % s3_key_prefix
243249
tmp = tempfile.mkdtemp()
250+
encrypt_artifact = True if settings is None else settings.encrypt_repacked_artifacts
244251

245252
try:
246253
source_files = _list_files_to_compress(script, directory) + dependencies
@@ -250,6 +257,10 @@ def tar_and_upload_dir(
250257

251258
if kms_key:
252259
extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": kms_key}
260+
elif encrypt_artifact:
261+
# encrypt the tarball at rest in S3 with the default AWS managed KMS key for S3
262+
# see https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutObject.html#API_PutObject_RequestSyntax
263+
extra_args = {"ServerSideEncryption": "aws:kms"}
253264
else:
254265
extra_args = None
255266

src/sagemaker/model.py

+1
Original file line numberDiff line numberDiff line change
@@ -1131,6 +1131,7 @@ def _upload_code(self, key_prefix, repack=False):
11311131
script=self.entry_point,
11321132
directory=self.source_dir,
11331133
dependencies=self.dependencies,
1134+
settings=self.sagemaker_session.settings,
11341135
)
11351136

11361137
if repack and self.model_data is not None and self.entry_point is not None:

src/sagemaker/session.py

+5
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
sts_regional_endpoint,
4343
)
4444
from sagemaker import exceptions
45+
from sagemaker.session_settings import SessionSettings
4546

4647
LOGGER = logging.getLogger("sagemaker")
4748

@@ -85,6 +86,7 @@ def __init__(
8586
sagemaker_runtime_client=None,
8687
sagemaker_featurestore_runtime_client=None,
8788
default_bucket=None,
89+
settings=SessionSettings(),
8890
):
8991
"""Initialize a SageMaker ``Session``.
9092
@@ -110,13 +112,16 @@ def __init__(
110112
If not provided, a default bucket will be created based on the following format:
111113
"sagemaker-{region}-{aws-account-id}".
112114
Example: "sagemaker-my-custom-bucket".
115+
settings (sagemaker.session_settings.SessionSettings): Optional. Set of optional
116+
parameters to apply to the session.
113117
"""
114118
self._default_bucket = None
115119
self._default_bucket_name_override = default_bucket
116120
self.s3_resource = None
117121
self.s3_client = None
118122
self.config = None
119123
self.lambda_client = None
124+
self.settings = settings
120125

121126
self._initialize(
122127
boto_session=boto_session,

src/sagemaker/session_settings.py

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Defines classes to parametrize a SageMaker ``Session``."""
14+
15+
from __future__ import absolute_import
16+
17+
18+
class SessionSettings(object):
19+
"""Optional container class for settings to apply to a SageMaker session."""
20+
21+
def __init__(self, encrypt_repacked_artifacts=True) -> None:
22+
"""Initialize the ``SessionSettings`` of a SageMaker ``Session``.
23+
24+
Args:
25+
encrypt_repacked_artifacts (bool): Flag to indicate whether to encrypt the artifacts
26+
at rest in S3 using the default AWS managed KMS key for S3 when a custom KMS key
27+
is not provided (Default: True).
28+
"""
29+
self._encrypt_repacked_artifacts = encrypt_repacked_artifacts
30+
31+
@property
32+
def encrypt_repacked_artifacts(self) -> bool:
33+
"""Return True if repacked artifacts at rest in S3 should be encrypted by default."""
34+
return self._encrypt_repacked_artifacts

src/sagemaker/utils.py

+8
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from six.moves.urllib import parse
3030

3131
from sagemaker import deprecations
32+
from sagemaker.session_settings import SessionSettings
3233

3334

3435
ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$"
@@ -429,8 +430,15 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
429430
bucket, key = url.netloc, url.path.lstrip("/")
430431
new_key = key.replace(os.path.basename(key), os.path.basename(repacked_model_uri))
431432

433+
settings = (
434+
sagemaker_session.settings if sagemaker_session is not None else SessionSettings()
435+
)
436+
encrypt_artifact = settings.encrypt_repacked_artifacts
437+
432438
if kms_key:
433439
extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": kms_key}
440+
elif encrypt_artifact:
441+
extra_args = {"ServerSideEncryption": "aws:kms"}
434442
else:
435443
extra_args = None
436444
sagemaker_session.boto_session.resource(

tests/unit/test_estimator.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2323,8 +2323,8 @@ def test_different_code_location_kms_key(utils, sagemaker_session):
23232323
obj = sagemaker_session.boto_session.resource("s3").Object
23242324

23252325
obj.assert_called_with("another-location", "%s/source/sourcedir.tar.gz" % fw._current_job_name)
2326-
2327-
obj().upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=None)
2326+
extra_args = {"ServerSideEncryption": "aws:kms"}
2327+
obj().upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=extra_args)
23282328

23292329

23302330
@patch("sagemaker.utils")

tests/unit/test_fw_utils.py

+35
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from sagemaker import fw_utils
2626
from sagemaker.utils import name_from_image
27+
from sagemaker.session_settings import SessionSettings
2728

2829
TIMESTAMP = "2017-10-10-14-14-15"
2930

@@ -93,6 +94,40 @@ def test_tar_and_upload_dir_s3_with_kms(utils, sagemaker_session):
9394
obj.upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=extra_args)
9495

9596

97+
@patch("sagemaker.utils")
98+
def test_tar_and_upload_dir_s3_kms_enabled_by_default(utils, sagemaker_session):
99+
bucket = "mybucket"
100+
s3_key_prefix = "something/source"
101+
script = "inference.py"
102+
result = fw_utils.tar_and_upload_dir(sagemaker_session, bucket, s3_key_prefix, script)
103+
104+
assert result == fw_utils.UploadedCode(
105+
"s3://{}/{}/sourcedir.tar.gz".format(bucket, s3_key_prefix), script
106+
)
107+
108+
extra_args = {"ServerSideEncryption": "aws:kms"}
109+
obj = sagemaker_session.resource("s3").Object("", "")
110+
obj.upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=extra_args)
111+
112+
113+
@patch("sagemaker.utils")
114+
def test_tar_and_upload_dir_s3_without_kms_with_overridden_settings(utils, sagemaker_session):
115+
bucket = "mybucket"
116+
s3_key_prefix = "something/source"
117+
script = "inference.py"
118+
settings = SessionSettings(encrypt_repacked_artifacts=False)
119+
result = fw_utils.tar_and_upload_dir(
120+
sagemaker_session, bucket, s3_key_prefix, script, settings=settings
121+
)
122+
123+
assert result == fw_utils.UploadedCode(
124+
"s3://{}/{}/sourcedir.tar.gz".format(bucket, s3_key_prefix), script
125+
)
126+
127+
obj = sagemaker_session.resource("s3").Object("", "")
128+
obj.upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=None)
129+
130+
96131
def test_mp_config_partition_exists():
97132
mp_parameters = {}
98133
with pytest.raises(ValueError):

tests/unit/test_utils.py

+25
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from mock import call, patch, Mock, MagicMock
2828

2929
import sagemaker
30+
from sagemaker.session_settings import SessionSettings
3031

3132
BUCKET_WITHOUT_WRITING_PERMISSION = "s3://bucket-without-writing-permission"
3233

@@ -390,6 +391,13 @@ def test_repack_model_without_source_dir(tmp, fake_s3):
390391
"/code/inference.py",
391392
}
392393

394+
extra_args = {"ServerSideEncryption": "aws:kms"}
395+
object_mock = fake_s3.object_mock
396+
_, _, kwargs = object_mock.mock_calls[0]
397+
398+
assert "ExtraArgs" in kwargs
399+
assert kwargs["ExtraArgs"] == extra_args
400+
393401

394402
def test_repack_model_with_entry_point_without_path_without_source_dir(tmp, fake_s3):
395403

@@ -415,12 +423,20 @@ def test_repack_model_with_entry_point_without_path_without_source_dir(tmp, fake
415423
"s3://fake/location",
416424
"s3://destination-bucket/model.tar.gz",
417425
fake_s3.sagemaker_session,
426+
kms_key="kms_key",
418427
)
419428
finally:
420429
os.chdir(cwd)
421430

422431
assert list_tar_files(fake_s3.fake_upload_path, tmp) == {"/code/inference.py", "/model"}
423432

433+
extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": "kms_key"}
434+
object_mock = fake_s3.object_mock
435+
_, _, kwargs = object_mock.mock_calls[0]
436+
437+
assert "ExtraArgs" in kwargs
438+
assert kwargs["ExtraArgs"] == extra_args
439+
424440

425441
def test_repack_model_from_s3_to_s3(tmp, fake_s3):
426442

@@ -434,6 +450,7 @@ def test_repack_model_from_s3_to_s3(tmp, fake_s3):
434450
)
435451

436452
fake_s3.tar_and_upload("model-dir", "s3://fake/location")
453+
fake_s3.sagemaker_session.settings = SessionSettings(encrypt_repacked_artifacts=False)
437454

438455
sagemaker.utils.repack_model(
439456
"inference.py",
@@ -450,6 +467,11 @@ def test_repack_model_from_s3_to_s3(tmp, fake_s3):
450467
"/model",
451468
}
452469

470+
object_mock = fake_s3.object_mock
471+
_, _, kwargs = object_mock.mock_calls[0]
472+
assert "ExtraArgs" in kwargs
473+
assert kwargs["ExtraArgs"] is None
474+
453475

454476
def test_repack_model_from_file_to_file(tmp):
455477
create_file_tree(tmp, ["model", "dependencies/a", "source-dir/inference.py"])
@@ -581,6 +603,7 @@ def __init__(self, tmp):
581603
self.sagemaker_session = MagicMock()
582604
self.location_map = {}
583605
self.current_bucket = None
606+
self.object_mock = MagicMock()
584607

585608
self.sagemaker_session.boto_session.resource().Bucket().download_file.side_effect = (
586609
self.download_file
@@ -606,6 +629,7 @@ def tar_and_upload(self, path, fake_location):
606629

607630
def mock_s3_upload(self):
608631
dst = os.path.join(self.tmp, "dst")
632+
object_mock = self.object_mock
609633

610634
class MockS3Object(object):
611635
def __init__(self, bucket, key):
@@ -616,6 +640,7 @@ def upload_file(self, target, **kwargs):
616640
if self.bucket in BUCKET_WITHOUT_WRITING_PERMISSION:
617641
raise exceptions.S3UploadFailedError()
618642
shutil.copy2(target, dst)
643+
object_mock.upload_file(target, **kwargs)
619644

620645
self.sagemaker_session.boto_session.resource().Object = MockS3Object
621646
return dst

tox.ini

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ exclude =
1919
.tox
2020
tests/data/
2121
venv/
22+
env/
2223

2324
max-complexity = 10
2425

0 commit comments

Comments
 (0)