Skip to content

Commit 93919bf

Browse files
drajeshkpengk19
authored and committed
feature: Add extra_args to enable encrypted objects upload (aws#836)
1 parent 337e96c commit 93919bf

File tree

3 files changed

+87
-10
lines changed

3 files changed

+87
-10
lines changed

src/sagemaker/session.py

+6-2
Original file line number | Diff line number | Diff line change
@@ -115,7 +115,7 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
115115
def boto_region_name(self):
116116
return self._region_name
117117

118-
def upload_data(self, path, bucket=None, key_prefix='data'):
118+
def upload_data(self, path, bucket=None, key_prefix='data', extra_args=None):
119119
"""Upload local file or directory to S3.
120120
121121
If a single file is specified for upload, the resulting S3 object key is ``{key_prefix}/{filename}``
@@ -132,6 +132,10 @@ def upload_data(self, path, bucket=None, key_prefix='data'):
132132
creates it).
133133
key_prefix (str): Optional S3 object key name prefix (default: 'data'). S3 uses the prefix to
134134
create a directory structure for the bucket content that it displays in the S3 console.
135+
extra_args (dict): Optional extra arguments that may be passed to the upload operation. Similar to
136+
ExtraArgs parameter in S3 upload_file function. Please refer to the ExtraArgs parameter
137+
documentation here:
138+
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-uploading-files.html#the-extraargs-parameter
135139
136140
Returns:
137141
str: The S3 URI of the uploaded file(s). If a file is specified in the path argument, the URI format is:
@@ -158,7 +162,7 @@ def upload_data(self, path, bucket=None, key_prefix='data'):
158162
s3 = self.boto_session.resource('s3')
159163

160164
for local_path, s3_key in files:
161-
s3.Object(bucket, s3_key).upload_file(local_path)
165+
s3.Object(bucket, s3_key).upload_file(local_path, ExtraArgs=extra_args)
162166

163167
s3_uri = 's3://{}/{}'.format(bucket, key_prefix)
164168
# If a specific file was used as input (instead of a directory), we return the full S3 key

tests/integ/test_data_upload.py

+45
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
17+
from six.moves.urllib.parse import urlparse
18+
19+
from tests.integ import DATA_DIR
20+
21+
AES_ENCRYPTION_ENABLED = {'ServerSideEncryption': 'AES256'}
22+
23+
24+
def test_upload_data_absolute_file(sagemaker_session):
25+
"""Test the method ``Session.upload_data`` can upload one encrypted file to S3 bucket"""
26+
data_path = os.path.join(DATA_DIR, 'upload_data_tests', 'file1.py')
27+
uploaded_file = sagemaker_session.upload_data(data_path, extra_args=AES_ENCRYPTION_ENABLED)
28+
parsed_url = urlparse(uploaded_file)
29+
s3_client = sagemaker_session.boto_session.client('s3')
30+
head = s3_client.head_object(Bucket=parsed_url.netloc, Key=parsed_url.path.lstrip('/'))
31+
assert head['ServerSideEncryption'] == 'AES256'
32+
33+
34+
def test_upload_data_absolute_dir(sagemaker_session):
35+
"""Test the method ``Session.upload_data`` can upload encrypted objects to S3 bucket"""
36+
data_path = os.path.join(DATA_DIR, 'upload_data_tests', 'nested_dir')
37+
uploaded_dir = sagemaker_session.upload_data(data_path, extra_args=AES_ENCRYPTION_ENABLED)
38+
parsed_url = urlparse(uploaded_dir)
39+
s3_bucket = parsed_url.netloc
40+
s3_prefix = parsed_url.path.lstrip('/')
41+
s3_client = sagemaker_session.boto_session.client('s3')
42+
for file in os.listdir(data_path):
43+
s3_key = '{}/{}'.format(s3_prefix, file)
44+
head = s3_client.head_object(Bucket=s3_bucket, Key=s3_key)
45+
assert head['ServerSideEncryption'] == 'AES256'

tests/unit/test_upload_data.py

+36-8
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424
SINGLE_FILE_NAME = 'file1.py'
2525
UPLOAD_DATA_TESTS_SINGLE_FILE = os.path.join(UPLOAD_DATA_TESTS_FILES_DIR, SINGLE_FILE_NAME)
2626
BUCKET_NAME = 'mybucket'
27+
AES_ENCRYPTION_ENABLED = {'ServerSideEncryption': 'AES256'}
2728

2829

2930
@pytest.fixture()
@@ -37,19 +38,46 @@ def sagemaker_session():
3738
def test_upload_data_absolute_dir(sagemaker_session):
3839
result_s3_uri = sagemaker_session.upload_data(UPLOAD_DATA_TESTS_FILES_DIR)
3940

40-
uploaded_files = [args[0] for name, args, kwargs in sagemaker_session.boto_session.mock_calls
41-
if name == 'resource().Object().upload_file']
41+
uploaded_files_with_args = [(args[0], kwargs) for name, args, kwargs in sagemaker_session.boto_session.mock_calls
42+
if name == 'resource().Object().upload_file']
4243
assert result_s3_uri == 's3://{}/data'.format(BUCKET_NAME)
43-
assert len(uploaded_files) == 4
44-
for file in uploaded_files:
44+
assert len(uploaded_files_with_args) == 4
45+
for file, kwargs in uploaded_files_with_args:
4546
assert os.path.exists(file)
47+
assert kwargs['ExtraArgs'] is None
4648

4749

4850
def test_upload_data_absolute_file(sagemaker_session):
4951
result_s3_uri = sagemaker_session.upload_data(UPLOAD_DATA_TESTS_SINGLE_FILE)
5052

51-
uploaded_files = [args[0] for name, args, kwargs in sagemaker_session.boto_session.mock_calls
52-
if name == 'resource().Object().upload_file']
53+
uploaded_files_with_args = [(args[0], kwargs) for name, args, kwargs in sagemaker_session.boto_session.mock_calls
54+
if name == 'resource().Object().upload_file']
5355
assert result_s3_uri == 's3://{}/data/{}'.format(BUCKET_NAME, SINGLE_FILE_NAME)
54-
assert len(uploaded_files) == 1
55-
assert os.path.exists(uploaded_files[0])
56+
assert len(uploaded_files_with_args) == 1
57+
(file, kwargs) = uploaded_files_with_args[0]
58+
assert os.path.exists(file)
59+
assert kwargs['ExtraArgs'] is None
60+
61+
62+
def test_upload_data_aes_encrypted_absolute_dir(sagemaker_session):
63+
result_s3_uri = sagemaker_session.upload_data(UPLOAD_DATA_TESTS_FILES_DIR, extra_args=AES_ENCRYPTION_ENABLED)
64+
65+
uploaded_files_with_args = [(args[0], kwargs) for name, args, kwargs in sagemaker_session.boto_session.mock_calls
66+
if name == 'resource().Object().upload_file']
67+
assert result_s3_uri == 's3://{}/data'.format(BUCKET_NAME)
68+
assert len(uploaded_files_with_args) == 4
69+
for file, kwargs in uploaded_files_with_args:
70+
assert os.path.exists(file)
71+
assert kwargs['ExtraArgs'] == AES_ENCRYPTION_ENABLED
72+
73+
74+
def test_upload_data_aes_encrypted_absolute_file(sagemaker_session):
75+
result_s3_uri = sagemaker_session.upload_data(UPLOAD_DATA_TESTS_SINGLE_FILE, extra_args=AES_ENCRYPTION_ENABLED)
76+
77+
uploaded_files_with_args = [(args[0], kwargs) for name, args, kwargs in sagemaker_session.boto_session.mock_calls
78+
if name == 'resource().Object().upload_file']
79+
assert result_s3_uri == 's3://{}/data/{}'.format(BUCKET_NAME, SINGLE_FILE_NAME)
80+
assert len(uploaded_files_with_args) == 1
81+
(file, kwargs) = uploaded_files_with_args[0]
82+
assert os.path.exists(file)
83+
assert kwargs['ExtraArgs'] == AES_ENCRYPTION_ENABLED

0 commit comments

Comments
 (0)