Skip to content

feat: default repack encryption #2821

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 7, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ venv/
*.swp
.docker/
env/
.vscode/
.vscode/
.python-version
1 change: 1 addition & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2343,6 +2343,7 @@ def _stage_user_code_in_s3(self):
dependencies=self.dependencies,
kms_key=kms_key,
s3_resource=self.sagemaker_session.s3_resource,
settings=self.sagemaker_session.settings,
)

def _model_source_dir(self):
Expand Down
11 changes: 11 additions & 0 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
import shutil
import tempfile
from collections import namedtuple
from typing import Optional

import sagemaker.image_uris
from sagemaker.session_settings import SessionSettings
import sagemaker.utils

from sagemaker.deprecations import renamed_warning
Expand Down Expand Up @@ -203,6 +205,7 @@ def tar_and_upload_dir(
dependencies=None,
kms_key=None,
s3_resource=None,
settings: Optional[SessionSettings] = None,
):
"""Package source files and upload a compress tar file to S3.

Expand Down Expand Up @@ -230,6 +233,9 @@ def tar_and_upload_dir(
s3_resource (boto3.resource("s3")): Optional. Pre-instantiated Boto3 Resource
for S3 connections, can be used to customize the configuration,
e.g. set the endpoint URL (default: None).
settings (sagemaker.session_settings.SessionSettings): Optional. The settings
of the SageMaker ``Session``, can be used to override the default encryption
behavior (default: None).
Returns:
sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and
script name.
Expand All @@ -241,6 +247,7 @@ def tar_and_upload_dir(
dependencies = dependencies or []
key = "%s/sourcedir.tar.gz" % s3_key_prefix
tmp = tempfile.mkdtemp()
encrypt_artifact = True if settings is None else settings.encrypt_repacked_artifacts

try:
source_files = _list_files_to_compress(script, directory) + dependencies
Expand All @@ -250,6 +257,10 @@ def tar_and_upload_dir(

if kms_key:
extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": kms_key}
elif encrypt_artifact:
# encrypt the tarball at rest in S3 with the default AWS managed KMS key for S3
# see https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutObject.html#API_PutObject_RequestSyntax
extra_args = {"ServerSideEncryption": "aws:kms"}
else:
extra_args = None

Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1131,6 +1131,7 @@ def _upload_code(self, key_prefix, repack=False):
script=self.entry_point,
directory=self.source_dir,
dependencies=self.dependencies,
settings=self.sagemaker_session.settings,
)

if repack and self.model_data is not None and self.entry_point is not None:
Expand Down
5 changes: 5 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
sts_regional_endpoint,
)
from sagemaker import exceptions
from sagemaker.session_settings import SessionSettings

LOGGER = logging.getLogger("sagemaker")

Expand Down Expand Up @@ -85,6 +86,7 @@ def __init__(
sagemaker_runtime_client=None,
sagemaker_featurestore_runtime_client=None,
default_bucket=None,
settings=SessionSettings(),
):
"""Initialize a SageMaker ``Session``.

Expand All @@ -110,13 +112,16 @@ def __init__(
If not provided, a default bucket will be created based on the following format:
"sagemaker-{region}-{aws-account-id}".
Example: "sagemaker-my-custom-bucket".
settings (sagemaker.session_settings.SessionSettings): Optional. Set of optional
parameters to apply to the session.
"""
self._default_bucket = None
self._default_bucket_name_override = default_bucket
self.s3_resource = None
self.s3_client = None
self.config = None
self.lambda_client = None
self.settings = settings

self._initialize(
boto_session=boto_session,
Expand Down
34 changes: 34 additions & 0 deletions src/sagemaker/session_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Defines classes to parametrize a SageMaker ``Session``."""

from __future__ import absolute_import


class SessionSettings(object):
    """Holder for the optional knobs that can be applied to a SageMaker session."""

    def __init__(self, encrypt_repacked_artifacts=True) -> None:
        """Create the ``SessionSettings`` for a SageMaker ``Session``.

        Args:
            encrypt_repacked_artifacts (bool): Whether artifacts stored at rest in S3
                are encrypted with the default AWS managed KMS key for S3 when no
                custom KMS key is supplied (Default: True).
        """
        # Stored exactly as given; the value is exposed read-only via the property.
        self._encrypt_repacked_artifacts = encrypt_repacked_artifacts

    @property
    def encrypt_repacked_artifacts(self) -> bool:
        """Tell whether repacked artifacts at rest in S3 are encrypted by default."""
        flag = self._encrypt_repacked_artifacts
        return flag
8 changes: 8 additions & 0 deletions src/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from six.moves.urllib import parse

from sagemaker import deprecations
from sagemaker.session_settings import SessionSettings


ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$"
Expand Down Expand Up @@ -429,8 +430,15 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
bucket, key = url.netloc, url.path.lstrip("/")
new_key = key.replace(os.path.basename(key), os.path.basename(repacked_model_uri))

settings = (
sagemaker_session.settings if sagemaker_session is not None else SessionSettings()
)
encrypt_artifact = settings.encrypt_repacked_artifacts

if kms_key:
extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": kms_key}
elif encrypt_artifact:
extra_args = {"ServerSideEncryption": "aws:kms"}
else:
extra_args = None
sagemaker_session.boto_session.resource(
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2323,8 +2323,8 @@ def test_different_code_location_kms_key(utils, sagemaker_session):
obj = sagemaker_session.boto_session.resource("s3").Object

obj.assert_called_with("another-location", "%s/source/sourcedir.tar.gz" % fw._current_job_name)

obj().upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=None)
extra_args = {"ServerSideEncryption": "aws:kms"}
obj().upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=extra_args)


@patch("sagemaker.utils")
Expand Down
35 changes: 35 additions & 0 deletions tests/unit/test_fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from sagemaker import fw_utils
from sagemaker.utils import name_from_image
from sagemaker.session_settings import SessionSettings

TIMESTAMP = "2017-10-10-14-14-15"

Expand Down Expand Up @@ -93,6 +94,40 @@ def test_tar_and_upload_dir_s3_with_kms(utils, sagemaker_session):
obj.upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=extra_args)


@patch("sagemaker.utils")
def test_tar_and_upload_dir_s3_kms_enabled_by_default(utils, sagemaker_session):
    # No KMS key and no settings are passed: the upload is expected to request
    # server-side encryption with "aws:kms" and no explicit key id.
    bucket, s3_key_prefix, script = "mybucket", "something/source", "inference.py"
    expected_uri = "s3://{}/{}/sourcedir.tar.gz".format(bucket, s3_key_prefix)

    result = fw_utils.tar_and_upload_dir(sagemaker_session, bucket, s3_key_prefix, script)

    assert result == fw_utils.UploadedCode(expected_uri, script)

    s3_object = sagemaker_session.resource("s3").Object("", "")
    s3_object.upload_file.assert_called_with(
        utils.create_tar_file(), ExtraArgs={"ServerSideEncryption": "aws:kms"}
    )


@patch("sagemaker.utils")
def test_tar_and_upload_dir_s3_without_kms_with_overridden_settings(utils, sagemaker_session):
    # Settings explicitly disable default encryption and no KMS key is given,
    # so the upload is expected to carry no extra arguments at all.
    bucket, s3_key_prefix, script = "mybucket", "something/source", "inference.py"
    expected_uri = "s3://{}/{}/sourcedir.tar.gz".format(bucket, s3_key_prefix)
    unencrypted = SessionSettings(encrypt_repacked_artifacts=False)

    result = fw_utils.tar_and_upload_dir(
        sagemaker_session, bucket, s3_key_prefix, script, settings=unencrypted
    )

    assert result == fw_utils.UploadedCode(expected_uri, script)

    s3_object = sagemaker_session.resource("s3").Object("", "")
    s3_object.upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=None)


def test_mp_config_partition_exists():
mp_parameters = {}
with pytest.raises(ValueError):
Expand Down
25 changes: 25 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from mock import call, patch, Mock, MagicMock

import sagemaker
from sagemaker.session_settings import SessionSettings

BUCKET_WITHOUT_WRITING_PERMISSION = "s3://bucket-without-writing-permission"

Expand Down Expand Up @@ -390,6 +391,13 @@ def test_repack_model_without_source_dir(tmp, fake_s3):
"/code/inference.py",
}

extra_args = {"ServerSideEncryption": "aws:kms"}
object_mock = fake_s3.object_mock
_, _, kwargs = object_mock.mock_calls[0]

assert "ExtraArgs" in kwargs
assert kwargs["ExtraArgs"] == extra_args


def test_repack_model_with_entry_point_without_path_without_source_dir(tmp, fake_s3):

Expand All @@ -415,12 +423,20 @@ def test_repack_model_with_entry_point_without_path_without_source_dir(tmp, fake
"s3://fake/location",
"s3://destination-bucket/model.tar.gz",
fake_s3.sagemaker_session,
kms_key="kms_key",
)
finally:
os.chdir(cwd)

assert list_tar_files(fake_s3.fake_upload_path, tmp) == {"/code/inference.py", "/model"}

extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": "kms_key"}
object_mock = fake_s3.object_mock
_, _, kwargs = object_mock.mock_calls[0]

assert "ExtraArgs" in kwargs
assert kwargs["ExtraArgs"] == extra_args


def test_repack_model_from_s3_to_s3(tmp, fake_s3):

Expand All @@ -434,6 +450,7 @@ def test_repack_model_from_s3_to_s3(tmp, fake_s3):
)

fake_s3.tar_and_upload("model-dir", "s3://fake/location")
fake_s3.sagemaker_session.settings = SessionSettings(encrypt_repacked_artifacts=False)

sagemaker.utils.repack_model(
"inference.py",
Expand All @@ -450,6 +467,11 @@ def test_repack_model_from_s3_to_s3(tmp, fake_s3):
"/model",
}

object_mock = fake_s3.object_mock
_, _, kwargs = object_mock.mock_calls[0]
assert "ExtraArgs" in kwargs
assert kwargs["ExtraArgs"] is None


def test_repack_model_from_file_to_file(tmp):
create_file_tree(tmp, ["model", "dependencies/a", "source-dir/inference.py"])
Expand Down Expand Up @@ -581,6 +603,7 @@ def __init__(self, tmp):
self.sagemaker_session = MagicMock()
self.location_map = {}
self.current_bucket = None
self.object_mock = MagicMock()

self.sagemaker_session.boto_session.resource().Bucket().download_file.side_effect = (
self.download_file
Expand All @@ -606,6 +629,7 @@ def tar_and_upload(self, path, fake_location):

def mock_s3_upload(self):
dst = os.path.join(self.tmp, "dst")
object_mock = self.object_mock

class MockS3Object(object):
def __init__(self, bucket, key):
Expand All @@ -616,6 +640,7 @@ def upload_file(self, target, **kwargs):
if self.bucket in BUCKET_WITHOUT_WRITING_PERMISSION:
raise exceptions.S3UploadFailedError()
shutil.copy2(target, dst)
object_mock.upload_file(target, **kwargs)

self.sagemaker_session.boto_session.resource().Object = MockS3Object
return dst
Expand Down
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ exclude =
.tox
tests/data/
venv/
env/

max-complexity = 10

Expand Down