
Commit fb4f105

feature: add encryption option to "record_set" (aws#794)

* feature: add encryption option to "record_set"

Authored by nickolasjwilson and committed by pengk19. Parent: aec0887.

File tree: 3 files changed, +85 -10 lines

  src/sagemaker/amazon/amazon_estimator.py   (+13 -8)
  tests/integ/test_record_set.py             (+42)
  tests/unit/test_amazon_estimator.py        (+30 -2)
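For context, a minimal usage sketch of the new option (not part of the commit itself). It follows the pattern of the integration test added below, but the role name, instance type, and random training matrix are placeholder assumptions.

```python
# Usage sketch only: how a caller opts in to server-side encryption once this
# change lands. Role name, instance type, and training data are placeholders.
import numpy as np
from sagemaker import KMeans

kmeans = KMeans(role='SageMakerRole',              # hypothetical IAM role
                train_instance_count=1,
                train_instance_type='ml.c4.xlarge',
                k=10)

train = np.random.rand(100, 784).astype('float32')  # placeholder training matrix

# encrypt=True uploads every shard and the manifest with SSE-S3 (AES-256);
# the default, encrypt=False, preserves the previous behaviour.
record_set = kmeans.record_set(train, encrypt=True)
kmeans.fit(record_set)
```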

src/sagemaker/amazon/amazon_estimator.py  (+13 -8)

@@ -159,7 +159,7 @@ def fit(self, records, mini_batch_size=None, wait=True, logs=True, job_name=None
         if wait:
             self.latest_training_job.wait(logs=logs)
 
-    def record_set(self, train, labels=None, channel="train"):
+    def record_set(self, train, labels=None, channel="train", encrypt=False):
         """Build a :class:`~RecordSet` from a numpy :class:`~ndarray` matrix and label vector.
 
         For the 2D ``ndarray`` ``train``, each row is converted to a :class:`~Record` object.
@@ -177,8 +177,10 @@ def record_set(self, train, labels=None, channel="train"):
         Args:
             train (numpy.ndarray): A 2D numpy array of training data.
             labels (numpy.ndarray): A 1D numpy array of labels. Its length must be equal to the
-               number of rows in ``train``.
+                number of rows in ``train``.
             channel (str): The SageMaker TrainingJob channel this RecordSet should be assigned to.
+            encrypt (bool): Specifies whether the objects uploaded to S3 are encrypted on the
+                server side using AES-256 (default: ``False``).
         Returns:
             RecordSet: A RecordSet referencing the encoded, uploading training and label data.
         """
@@ -188,7 +190,8 @@ def record_set(self, train, labels=None, channel="train"):
         key_prefix = key_prefix + '{}-{}/'.format(type(self).__name__, sagemaker_timestamp())
         key_prefix = key_prefix.lstrip('/')
         logger.debug('Uploading to bucket {} and key_prefix {}'.format(bucket, key_prefix))
-        manifest_s3_file = upload_numpy_to_s3_shards(self.train_instance_count, s3, bucket, key_prefix, train, labels)
+        manifest_s3_file = upload_numpy_to_s3_shards(self.train_instance_count, s3, bucket,
+                                                     key_prefix, train, labels, encrypt)
         logger.debug("Created manifest file {}".format(manifest_s3_file))
         return RecordSet(manifest_s3_file, num_records=train.shape[0], feature_dim=train.shape[1], channel=channel)
 
@@ -239,15 +242,17 @@ def _build_shards(num_shards, array):
     return shards
 
 
-def upload_numpy_to_s3_shards(num_shards, s3, bucket, key_prefix, array, labels=None):
-    """Upload the training ``array`` and ``labels`` arrays to ``num_shards`` s3 objects,
-    stored in "s3://``bucket``/``key_prefix``/"."""
+def upload_numpy_to_s3_shards(num_shards, s3, bucket, key_prefix, array, labels=None, encrypt=False):
+    """Upload the training ``array`` and ``labels`` arrays to ``num_shards`` S3 objects,
+    stored in "s3://``bucket``/``key_prefix``/". Optionally ``encrypt`` the S3 objects using
+    AES-256."""
     shards = _build_shards(num_shards, array)
     if labels is not None:
         label_shards = _build_shards(num_shards, labels)
     uploaded_files = []
     if key_prefix[-1] != '/':
         key_prefix = key_prefix + '/'
+    extra_put_kwargs = {'ServerSideEncryption': 'AES256'} if encrypt else {}
     try:
         for shard_index, shard in enumerate(shards):
             with tempfile.TemporaryFile() as file:
@@ -260,12 +265,12 @@ def upload_numpy_to_s3_shards(num_shards, s3, bucket, key_prefix, array, labels=
                 file_name = "matrix_{}.pbr".format(shard_index_string)
                 key = key_prefix + file_name
                 logger.debug("Creating object {} in bucket {}".format(key, bucket))
-                s3.Object(bucket, key).put(Body=file)
+                s3.Object(bucket, key).put(Body=file, **extra_put_kwargs)
                 uploaded_files.append(file_name)
         manifest_key = key_prefix + ".amazon.manifest"
         manifest_str = json.dumps(
             [{'prefix': 's3://{}/{}'.format(bucket, key_prefix)}] + uploaded_files)
-        s3.Object(bucket, manifest_key).put(Body=manifest_str.encode('utf-8'))
+        s3.Object(bucket, manifest_key).put(Body=manifest_str.encode('utf-8'), **extra_put_kwargs)
         return "s3://{}/{}".format(bucket, manifest_key)
     except Exception as ex:  # pylint: disable=broad-except
         try:
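The change above threads an `extra_put_kwargs` dict into every `put()` call so the unencrypted code path is untouched. A standalone sketch of the underlying boto3 mechanism (bucket name, key, and payload are placeholder assumptions; credentials and the bucket must already exist):

```python
# Sketch of the boto3 call shape used by upload_numpy_to_s3_shards.
import boto3

s3 = boto3.resource('s3')
extra_put_kwargs = {'ServerSideEncryption': 'AES256'}  # what the diff builds when encrypt=True

# The kwargs are only added when encryption was requested, so the default
# behaviour of put() is unchanged.
s3.Object('my-bucket', 'key-prefix/matrix_0.pbr').put(Body=b'payload', **extra_put_kwargs)

# S3 reports the encryption on the stored object, which is what the new
# integration test asserts on.
head = boto3.client('s3').head_object(Bucket='my-bucket', Key='key-prefix/matrix_0.pbr')
print(head['ServerSideEncryption'])  # 'AES256'
```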

tests/integ/test_record_set.py  (new file, +42)

@@ -0,0 +1,42 @@
+# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import gzip
+import os
+import pickle
+import sys
+
+from six.moves.urllib.parse import urlparse
+
+from sagemaker import KMeans
+from tests.integ import DATA_DIR
+
+
+def test_record_set(sagemaker_session):
+    """Test the method ``AmazonAlgorithmEstimatorBase.record_set``.
+
+    In particular, test that the objects uploaded to the S3 bucket are encrypted.
+    """
+    data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
+    pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
+    with gzip.open(data_path, 'rb') as file_object:
+        train_set, _, _ = pickle.load(file_object, **pickle_args)
+    kmeans = KMeans(role='SageMakerRole', train_instance_count=1,
+                    train_instance_type='ml.c4.xlarge',
+                    k=10, sagemaker_session=sagemaker_session)
+    record_set = kmeans.record_set(train_set[0][:100], encrypt=True)
+    parsed_url = urlparse(record_set.s3_data)
+    s3_client = sagemaker_session.boto_session.client('s3')
+    head = s3_client.head_object(Bucket=parsed_url.netloc, Key=parsed_url.path.lstrip('/'))
+    assert head['ServerSideEncryption'] == 'AES256'
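The test above asserts encryption only on the manifest object. A hedged extension (not part of the commit) that would also check each shard object under the same prefix; the helper name and prefix derivation are illustrative assumptions:

```python
# Check SSE-S3 on every object under the upload prefix, not just the manifest.
from six.moves.urllib.parse import urlparse


def assert_prefix_encrypted(s3_client, manifest_s3_url):
    parsed_url = urlparse(manifest_s3_url)
    bucket = parsed_url.netloc
    # Drop the trailing ".amazon.manifest" component to get the shard prefix.
    prefix = parsed_url.path.lstrip('/').rsplit('/', 1)[0] + '/'
    listing = s3_client.list_objects_v2(Bucket=bucket, Prefix=prefix)
    for obj in listing.get('Contents', []):
        head = s3_client.head_object(Bucket=bucket, Key=obj['Key'])
        assert head['ServerSideEncryption'] == 'AES256'
```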

tests/unit/test_amazon_estimator.py  (+30 -2)

@@ -14,7 +14,7 @@
 
 import numpy as np
 import pytest
-from mock import Mock, patch, call
+from mock import ANY, Mock, patch, call
 
 # Use PCA as a test implementation of AmazonAlgorithmEstimator
 from sagemaker.amazon.pca import PCA
@@ -143,6 +143,22 @@ def test_prepare_for_training_list_no_train_channel(sagemaker_session):
     assert 'Must provide train channel.' in str(ex)
 
 
+def test_prepare_for_training_encrypt(sagemaker_session):
+    pca = PCA(num_components=55, sagemaker_session=sagemaker_session, **COMMON_ARGS)
+
+    train = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 8.0], [44.0, 55.0, 66.0]]
+    labels = [99, 85, 87, 2]
+    with patch('sagemaker.amazon.amazon_estimator.upload_numpy_to_s3_shards',
+               return_value='manfiest_file') as mock_upload:
+        pca.record_set(np.array(train), np.array(labels))
+        pca.record_set(np.array(train), np.array(labels), encrypt=True)
+
+    def make_upload_call(encrypt):
+        return call(ANY, ANY, ANY, ANY, ANY, ANY, encrypt)
+
+    mock_upload.assert_has_calls([make_upload_call(False), make_upload_call(True)])
+
+
 @patch('time.strftime', return_value=TIMESTAMP)
 def test_fit_ndarray(time, sagemaker_session):
     mock_s3 = Mock()
@@ -185,9 +201,21 @@ def test_upload_numpy_to_s3_shards():
     mock_s3 = Mock()
     mock_object = Mock()
     mock_s3.Object = Mock(return_value=mock_object)
+    mock_put = mock_s3.Object.return_value.put
     array = np.array([[j for j in range(10)] for i in range(10)])
     labels = np.array([i for i in range(10)])
-    upload_numpy_to_s3_shards(3, mock_s3, BUCKET_NAME, "key-prefix", array, labels)
+    num_shards = 3
+    num_objects = num_shards + 1  # Account for the manifest file.
+
+    def make_all_put_calls(**kwargs):
+        return [call(Body=ANY, **kwargs) for i in range(num_objects)]
+
+    upload_numpy_to_s3_shards(num_shards, mock_s3, BUCKET_NAME, "key-prefix", array, labels)
     mock_s3.Object.assert_has_calls([call(BUCKET_NAME, 'key-prefix/matrix_0.pbr')])
     mock_s3.Object.assert_has_calls([call(BUCKET_NAME, 'key-prefix/matrix_1.pbr')])
     mock_s3.Object.assert_has_calls([call(BUCKET_NAME, 'key-prefix/matrix_2.pbr')])
+    mock_put.assert_has_calls(make_all_put_calls())
+
+    mock_put.reset()
+    upload_numpy_to_s3_shards(3, mock_s3, BUCKET_NAME, "key-prefix", array, labels, encrypt=True)
+    mock_put.assert_has_calls(make_all_put_calls(ServerSideEncryption='AES256'))
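For readers unfamiliar with the `mock.ANY` idiom used in `test_prepare_for_training_encrypt`, a self-contained sketch showing how `assert_has_calls` pins down only the trailing `encrypt` argument while ignoring the rest (argument values are placeholders):

```python
# ANY matches any value, so only the final positional argument is constrained.
from mock import ANY, Mock, call

upload = Mock()
upload('instance_count', 's3', 'bucket', 'prefix', 'array', 'labels', False)
upload('instance_count', 's3', 'bucket', 'prefix', 'array', 'labels', True)

upload.assert_has_calls([
    call(ANY, ANY, ANY, ANY, ANY, ANY, False),
    call(ANY, ANY, ANY, ANY, ANY, ANY, True),
])
```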
