Skip to content

Commit 34bdbf7

Browse files
committed
pass kms id as parameter for uploading code with Server side encryption
1 parent 9c76287 commit 34bdbf7

File tree

7 files changed

+202
-4
lines changed

7 files changed

+202
-4
lines changed

CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
CHANGELOG
33
=========
44

5+
1.18.5.dev
6+
==========
7+
8+
* bug-fix: pass kms id as parameter for uploading code with Server side encryption
59

610
1.18.4
711
======

src/sagemaker/estimator.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -863,18 +863,23 @@ def _stage_user_code_in_s3(self):
863863
864864
"""
865865
if self.code_location is None:
866-
code_bucket = self.sagemaker_session.default_bucket()
866+
code_bucket, _ = parse_s3_url(self.output_path)
867867
code_s3_prefix = '{}/source'.format(self._current_job_name)
868868
else:
869869
code_bucket, key_prefix = parse_s3_url(self.code_location)
870870
code_s3_prefix = '/'.join(filter(None, [key_prefix, self._current_job_name, 'source']))
871871

872+
output_bucket, _ = parse_s3_url(self.output_path)
873+
874+
kms_key = self.output_kms_key if code_bucket == output_bucket else None
875+
872876
return tar_and_upload_dir(session=self.sagemaker_session.boto_session,
873877
bucket=code_bucket,
874878
s3_key_prefix=code_s3_prefix,
875879
script=self.entry_point,
876880
directory=self.source_dir,
877-
dependencies=self.dependencies)
881+
dependencies=self.dependencies,
882+
kms_key=kms_key)
878883

879884
def _model_source_dir(self):
880885
"""Get the appropriate value to pass as source_dir to model constructor on deploying

src/sagemaker/fw_utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ def validate_source_dir(script, directory):
136136
return True
137137

138138

139-
def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory, dependencies=None):
139+
def tar_and_upload_dir(session, bucket, s3_key_prefix, script,
140+
directory=None, dependencies=None, kms_key=None):
140141
"""Package source files and upload a compress tar file to S3. The S3 location will be
141142
``s3://<bucket>/s3_key_prefix/sourcedir.tar.gz``.
142143
@@ -159,6 +160,7 @@ def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory, depend
159160
dependencies (List[str]): Optional. A list of paths to directories (absolute or relative)
160161
containing additional libraries that will be copied into
161162
/opt/ml/lib
163+
kms_key (str): Optional. KMS key ID used to upload objects to the bucket (default: None).
162164
163165
Returns:
164166
sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and
@@ -177,7 +179,12 @@ def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory, depend
177179
tar_file = sagemaker.utils.create_tar_file(source_files,
178180
os.path.join(tmp, _TAR_SOURCE_FILENAME))
179181

180-
session.resource('s3').Object(bucket, key).upload_file(tar_file)
182+
if kms_key:
183+
extra_args = {'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': kms_key}
184+
else:
185+
extra_args = None
186+
187+
session.resource('s3').Object(bucket, key).upload_file(tar_file, ExtraArgs=extra_args)
181188
finally:
182189
shutil.rmtree(tmp)
183190

tests/integ/kms_utils.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
from botocore import exceptions
16+
1517
KEY_ALIAS = "SageMakerKmsKey"
1618
KEY_POLICY = '''
1719
{{
@@ -89,3 +91,78 @@ def get_or_create_kms_key(kms_client, account_id):
8991
return kms_key_arn
9092
else:
9193
return _create_kms_key(kms_client, account_id)
94+
95+
96+
KMS_BUCKET_POLICY = """{
97+
"Version": "2012-10-17",
98+
"Id": "PutObjPolicy",
99+
"Statement": [
100+
{
101+
"Sid": "DenyIncorrectEncryptionHeader",
102+
"Effect": "Deny",
103+
"Principal": "*",
104+
"Action": "s3:PutObject",
105+
"Resource": "arn:aws:s3:::%s/*",
106+
"Condition": {
107+
"StringNotEquals": {
108+
"s3:x-amz-server-side-encryption": "aws:kms"
109+
}
110+
}
111+
},
112+
{
113+
"Sid": "DenyUnEncryptedObjectUploads",
114+
"Effect": "Deny",
115+
"Principal": "*",
116+
"Action": "s3:PutObject",
117+
"Resource": "arn:aws:s3:::%s/*",
118+
"Condition": {
119+
"Null": {
120+
"s3:x-amz-server-side-encryption": "true"
121+
}
122+
}
123+
}
124+
]
125+
}"""
126+
127+
128+
def get_or_create_bucket_with_encryption(boto_session):
129+
account = boto_session.client('sts').get_caller_identity()['Account']
130+
kms_key_arn = get_or_create_kms_key(boto_session.client('kms'), account)
131+
132+
region = boto_session.region_name
133+
bucket_name = 'sagemaker-{}-{}-with-kms'.format(region, account)
134+
135+
s3 = boto_session.client('s3')
136+
try:
137+
# 'us-east-1' cannot be specified because it is the default region:
138+
# https://github.com/boto/boto3/issues/125
139+
if region == 'us-east-1':
140+
s3.create_bucket(Bucket=bucket_name)
141+
else:
142+
s3.create_bucket(Bucket=bucket_name,
143+
CreateBucketConfiguration={'LocationConstraint': region})
144+
145+
except exceptions.ClientError as e:
146+
if e.response['Error']['Code'] != 'BucketAlreadyOwnedByYou':
147+
raise
148+
149+
s3.put_bucket_encryption(
150+
Bucket=bucket_name,
151+
ServerSideEncryptionConfiguration={
152+
'Rules': [
153+
{
154+
'ApplyServerSideEncryptionByDefault': {
155+
'SSEAlgorithm': 'aws:kms',
156+
'KMSMasterKeyID': kms_key_arn
157+
}
158+
},
159+
]
160+
}
161+
)
162+
163+
s3.put_bucket_policy(
164+
Bucket=bucket_name,
165+
Policy=KMS_BUCKET_POLICY % (bucket_name, bucket_name)
166+
)
167+
168+
return 's3://' + bucket_name, kms_key_arn

tests/integ/test_tf_script_mode.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,15 @@
1313
from __future__ import absolute_import
1414

1515
import os
16+
import time
17+
1618
import pytest
1719

1820
import boto3
1921
from sagemaker.tensorflow import TensorFlow
2022
from six.moves.urllib.parse import urlparse
2123
import tests.integ as integ
24+
from tests.integ import kms_utils
2225
import tests.integ.timeout as timeout
2326

2427
RESOURCE_PATH = os.path.join(os.path.dirname(__file__), '..', 'data', 'tensorflow_mnist')
@@ -52,6 +55,33 @@ def test_mnist(sagemaker_session, instance_type):
5255
['graph.pbtxt', 'model.ckpt-0.index', 'model.ckpt-0.meta', 'saved_model.pb'])
5356

5457

58+
def test_server_side_encryption(sagemaker_session):
59+
60+
bucket_with_kms, kms_key = kms_utils.get_or_create_bucket_with_encryption(sagemaker_session.boto_session)
61+
62+
output_path = os.path.join(bucket_with_kms, 'test-server-side-encryption', time.strftime('%y%m%d-%H%M'))
63+
64+
estimator = TensorFlow(entry_point=SCRIPT,
65+
role='SageMakerRole',
66+
train_instance_count=1,
67+
train_instance_type='ml.c5.xlarge',
68+
sagemaker_session=sagemaker_session,
69+
py_version='py3',
70+
framework_version='1.11',
71+
base_job_name='test-server-side-encryption',
72+
code_location=output_path,
73+
output_path=output_path,
74+
model_dir='/opt/ml/model',
75+
output_kms_key=kms_key)
76+
77+
inputs = estimator.sagemaker_session.upload_data(
78+
path=os.path.join(RESOURCE_PATH, 'data'),
79+
key_prefix='scriptmode/mnist')
80+
81+
with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
82+
estimator.fit(inputs)
83+
84+
5585
@pytest.mark.canary_quick
5686
@pytest.mark.skipif(integ.PYTHON_VERSION != 'py3', reason="Script Mode tests are only configured to run with Python 3")
5787
def test_mnist_distributed(sagemaker_session, instance_type):

tests/unit/test_estimator.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,64 @@ def test_container_log_level(sagemaker_session):
761761
assert train_kwargs['hyperparameters']['sagemaker_container_log_level'] == '10'
762762

763763

764+
@patch('sagemaker.utils')
765+
def test_same_code_location_keeps_kms_key(utils, sagemaker_session):
766+
fw = DummyFramework(entry_point=SCRIPT_PATH,
767+
role='DummyRole',
768+
sagemaker_session=sagemaker_session,
769+
train_instance_count=INSTANCE_COUNT,
770+
train_instance_type=INSTANCE_TYPE,
771+
output_kms_key='kms-key')
772+
773+
fw.fit(wait=False)
774+
775+
extra_args = {'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': 'kms-key'}
776+
obj = sagemaker_session.boto_session.resource('s3').Object
777+
778+
obj.assert_called_with('mybucket', '%s/source/sourcedir.tar.gz' % fw._current_job_name)
779+
780+
obj().upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=extra_args)
781+
782+
783+
@patch('sagemaker.utils')
784+
def test_different_code_location_kms_key(utils, sagemaker_session):
785+
fw = DummyFramework(entry_point=SCRIPT_PATH,
786+
role='DummyRole',
787+
sagemaker_session=sagemaker_session,
788+
code_location='s3://another-location',
789+
train_instance_count=INSTANCE_COUNT,
790+
train_instance_type=INSTANCE_TYPE,
791+
output_kms_key='kms-key')
792+
793+
fw.fit(wait=False)
794+
795+
obj = sagemaker_session.boto_session.resource('s3').Object
796+
797+
obj.assert_called_with('another-location', '%s/source/sourcedir.tar.gz' % fw._current_job_name)
798+
799+
obj().upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=None)
800+
801+
802+
@patch('sagemaker.utils')
803+
def test_default_code_location_uses_output_path(utils, sagemaker_session):
804+
fw = DummyFramework(entry_point=SCRIPT_PATH,
805+
role='DummyRole',
806+
sagemaker_session=sagemaker_session,
807+
output_path='s3://output_path',
808+
train_instance_count=INSTANCE_COUNT,
809+
train_instance_type=INSTANCE_TYPE,
810+
output_kms_key='kms-key')
811+
812+
fw.fit(wait=False)
813+
814+
obj = sagemaker_session.boto_session.resource('s3').Object
815+
816+
obj.assert_called_with('output_path', '%s/source/sourcedir.tar.gz' % fw._current_job_name)
817+
818+
extra_args = {'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': 'kms-key'}
819+
obj().upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=extra_args)
820+
821+
764822
def test_wait_without_logs(sagemaker_session):
765823
training_job = _TrainingJob(sagemaker_session, JOB_NAME)
766824

tests/unit/test_fw_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,23 @@ def test_tar_and_upload_dir_s3(sagemaker_session):
165165
assert result == fw_utils.UploadedCode('s3://m', 'mnist.py')
166166

167167

168+
@patch('sagemaker.utils')
169+
def test_tar_and_upload_dir_s3_with_kms(utils, sagemaker_session):
170+
171+
result = fw_utils.tar_and_upload_dir(sagemaker_session,
172+
'mybucker',
173+
'something/source',
174+
'mnist.py',
175+
kms_key='kms-key')
176+
177+
assert result == fw_utils.UploadedCode('s3://mybucker/something/source/sourcedir.tar.gz',
178+
'mnist.py')
179+
180+
extra_args = {'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': 'kms-key'}
181+
obj = sagemaker_session.resource('s3').Object('', '')
182+
obj.upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=extra_args)
183+
184+
168185
def test_validate_source_dir_does_not_exits(sagemaker_session):
169186
script = 'mnist.py'
170187
directory = ' !@#$%^&*()path probably in not there.!@#$%^&*()'

0 commit comments

Comments
 (0)