@@ -159,7 +159,7 @@ def fit(self, records, mini_batch_size=None, wait=True, logs=True, job_name=None
159
159
if wait :
160
160
self .latest_training_job .wait (logs = logs )
161
161
162
- def record_set (self , train , labels = None , channel = "train" ):
162
+ def record_set (self , train , labels = None , channel = "train" , encrypt = False ):
163
163
"""Build a :class:`~RecordSet` from a numpy :class:`~ndarray` matrix and label vector.
164
164
165
165
For the 2D ``ndarray`` ``train``, each row is converted to a :class:`~Record` object.
@@ -177,8 +177,10 @@ def record_set(self, train, labels=None, channel="train"):
177
177
Args:
178
178
train (numpy.ndarray): A 2D numpy array of training data.
179
179
labels (numpy.ndarray): A 1D numpy array of labels. Its length must be equal to the
180
- number of rows in ``train``.
180
+ number of rows in ``train``.
181
181
channel (str): The SageMaker TrainingJob channel this RecordSet should be assigned to.
182
+ encrypt (bool): Specifies whether the objects uploaded to S3 are encrypted on the
183
+ server side using AES-256 (default: ``False``).
182
184
Returns:
183
185
RecordSet: A RecordSet referencing the encoded, uploading training and label data.
184
186
"""
@@ -188,7 +190,8 @@ def record_set(self, train, labels=None, channel="train"):
188
190
key_prefix = key_prefix + '{}-{}/' .format (type (self ).__name__ , sagemaker_timestamp ())
189
191
key_prefix = key_prefix .lstrip ('/' )
190
192
logger .debug ('Uploading to bucket {} and key_prefix {}' .format (bucket , key_prefix ))
191
- manifest_s3_file = upload_numpy_to_s3_shards (self .train_instance_count , s3 , bucket , key_prefix , train , labels )
193
+ manifest_s3_file = upload_numpy_to_s3_shards (self .train_instance_count , s3 , bucket ,
194
+ key_prefix , train , labels , encrypt )
192
195
logger .debug ("Created manifest file {}" .format (manifest_s3_file ))
193
196
return RecordSet (manifest_s3_file , num_records = train .shape [0 ], feature_dim = train .shape [1 ], channel = channel )
194
197
@@ -239,15 +242,17 @@ def _build_shards(num_shards, array):
239
242
return shards
240
243
241
244
242
- def upload_numpy_to_s3_shards (num_shards , s3 , bucket , key_prefix , array , labels = None ):
243
- """Upload the training ``array`` and ``labels`` arrays to ``num_shards`` s3 objects,
244
- stored in "s3://``bucket``/``key_prefix``/"."""
245
+ def upload_numpy_to_s3_shards (num_shards , s3 , bucket , key_prefix , array , labels = None , encrypt = False ):
246
+ """Upload the training ``array`` and ``labels`` arrays to ``num_shards`` S3 objects,
247
+ stored in "s3://``bucket``/``key_prefix``/". Optionally ``encrypt`` the S3 objects using
248
+ AES-256."""
245
249
shards = _build_shards (num_shards , array )
246
250
if labels is not None :
247
251
label_shards = _build_shards (num_shards , labels )
248
252
uploaded_files = []
249
253
if key_prefix [- 1 ] != '/' :
250
254
key_prefix = key_prefix + '/'
255
+ extra_put_kwargs = {'ServerSideEncryption' : 'AES256' } if encrypt else {}
251
256
try :
252
257
for shard_index , shard in enumerate (shards ):
253
258
with tempfile .TemporaryFile () as file :
@@ -260,12 +265,12 @@ def upload_numpy_to_s3_shards(num_shards, s3, bucket, key_prefix, array, labels=
260
265
file_name = "matrix_{}.pbr" .format (shard_index_string )
261
266
key = key_prefix + file_name
262
267
logger .debug ("Creating object {} in bucket {}" .format (key , bucket ))
263
- s3 .Object (bucket , key ).put (Body = file )
268
+ s3 .Object (bucket , key ).put (Body = file , ** extra_put_kwargs )
264
269
uploaded_files .append (file_name )
265
270
manifest_key = key_prefix + ".amazon.manifest"
266
271
manifest_str = json .dumps (
267
272
[{'prefix' : 's3://{}/{}' .format (bucket , key_prefix )}] + uploaded_files )
268
- s3 .Object (bucket , manifest_key ).put (Body = manifest_str .encode ('utf-8' ))
273
+ s3 .Object (bucket , manifest_key ).put (Body = manifest_str .encode ('utf-8' ), ** extra_put_kwargs )
269
274
return "s3://{}/{}" .format (bucket , manifest_key )
270
275
except Exception as ex : # pylint: disable=broad-except
271
276
try :
0 commit comments