Skip to content

Commit d76cd2b

Browse files
authored
Fix local mode not using the right S3 bucket. (#144)
* Fix local mode not using the right S3 bucket. Local mode should honor the bucket named in each input's S3 URI instead of wrongly assuming that everyone is using the default bucket.
1 parent 524dc86 commit d76cd2b

File tree

3 files changed

+50
-10
lines changed

3 files changed

+50
-10
lines changed

CHANGELOG.rst

+4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
CHANGELOG
33
=========
44

5+
1.2.3-dev
6+
=========
7+
* bug-fix: Fix local mode not using the right s3 bucket
8+
59
1.2.2
610
=====
711

src/sagemaker/local/image.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ def train(self, input_data_config, hyperparameters):
8787
os.mkdir(os.path.join(self.container_root, 'output'))
8888

8989
data_dir = self._create_tmp_folder()
90-
bucket_name = self.sagemaker_session.default_bucket()
9190
volumes = []
9291

9392
# Set up the channels for the containers. For local data we will
@@ -102,7 +101,8 @@ def train(self, input_data_config, hyperparameters):
102101
channel_dir = os.path.join(data_dir, channel_name)
103102
os.mkdir(channel_dir)
104103

105-
if uri.lower().startswith("s3://"):
104+
if parsed_uri.scheme == 's3':
105+
bucket_name = parsed_uri.netloc
106106
self._download_folder(bucket_name, key, channel_dir)
107107
else:
108108
volumes.append(_Volume(uri, channel=channel_name))

tests/unit/test_image.py

+44-8
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import pytest
1818
import yaml
19-
from mock import patch, Mock
19+
from mock import call, patch, Mock
2020

2121
import sagemaker
2222
from sagemaker.local.image import _SageMakerContainer
@@ -40,7 +40,7 @@
4040
'S3DataSource': {
4141
'S3DataDistributionType': 'FullyReplicated',
4242
'S3DataType': 'S3Prefix',
43-
'S3Uri': 's3://foo/bar'
43+
'S3Uri': 's3://my-own-bucket/prefix'
4444
}
4545
}
4646
}
@@ -54,12 +54,12 @@ def sagemaker_session():
5454
boto_mock.client('sts').get_caller_identity.return_value = {'Account': '123'}
5555
boto_mock.resource('s3').Bucket(BUCKET_NAME).objects.filter.return_value = []
5656

57-
ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock())
57+
sms = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock())
5858

59-
ims.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
60-
ims.expand_role = Mock(return_value=EXPANDED_ROLE)
59+
sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
60+
sms.expand_role = Mock(return_value=EXPANDED_ROLE)
6161

62-
return ims
62+
return sms
6363

6464

6565
@patch('sagemaker.local.local_session.LocalSession')
@@ -181,16 +181,22 @@ def test_check_output():
181181
@patch('sagemaker.local.local_session.LocalSession')
182182
@patch('sagemaker.local.image._execute_and_stream_output')
183183
@patch('sagemaker.local.image._SageMakerContainer._cleanup')
184-
def test_train(LocalSession, _execute_and_stream_output, _cleanup, tmpdir, sagemaker_session):
184+
@patch('sagemaker.local.image._SageMakerContainer._download_folder')
185+
def test_train(_download_folder, _cleanup, _execute_and_stream_output, LocalSession, tmpdir, sagemaker_session):
185186

187+
directories = [str(tmpdir.mkdir('container-root')), str(tmpdir.mkdir('data'))]
186188
with patch('sagemaker.local.image._SageMakerContainer._create_tmp_folder',
187-
side_effect=[str(tmpdir.mkdir('container-root')), str(tmpdir.mkdir('data'))]):
189+
side_effect=directories):
188190

189191
instance_count = 2
190192
image = 'my-image'
191193
sagemaker_container = _SageMakerContainer('local', instance_count, image, sagemaker_session=sagemaker_session)
192194
sagemaker_container.train(INPUT_DATA_CONFIG, HYPERPARAMETERS)
193195

196+
channel_dir = os.path.join(directories[1], 'b')
197+
download_folder_calls = [call('my-own-bucket', 'prefix', channel_dir)]
198+
_download_folder.assert_has_calls(download_folder_calls)
199+
194200
docker_compose_file = os.path.join(sagemaker_container.container_root, 'docker-compose.yaml')
195201

196202
call_args = _execute_and_stream_output.call_args[0][0]
@@ -231,6 +237,36 @@ def test_serve(up, copy, copytree, tmpdir, sagemaker_session):
231237
assert config['services'][h]['command'] == 'serve'
232238

233239

240+
@patch('os.makedirs')
241+
def test_download_folder(makedirs):
242+
boto_mock = Mock(name='boto_session')
243+
boto_mock.client('sts').get_caller_identity.return_value = {'Account': '123'}
244+
245+
session = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock())
246+
247+
train_data = Mock()
248+
validation_data = Mock()
249+
250+
train_data.bucket_name.return_value = BUCKET_NAME
251+
train_data.key = '/prefix/train/train_data.csv'
252+
validation_data.bucket_name.return_value = BUCKET_NAME
253+
validation_data.key = '/prefix/train/validation_data.csv'
254+
255+
s3_files = [train_data, validation_data]
256+
boto_mock.resource('s3').Bucket(BUCKET_NAME).objects.filter.return_value = s3_files
257+
258+
obj_mock = Mock()
259+
boto_mock.resource('s3').Object.return_value = obj_mock
260+
261+
sagemaker_container = _SageMakerContainer('local', 2, 'my-image', sagemaker_session=session)
262+
sagemaker_container._download_folder(BUCKET_NAME, '/prefix', '/tmp')
263+
264+
obj_mock.download_file.assert_called()
265+
calls = [call(os.path.join('/tmp', 'train/train_data.csv')),
266+
call(os.path.join('/tmp', 'train/validation_data.csv'))]
267+
obj_mock.download_file.assert_has_calls(calls)
268+
269+
234270
def test_ecr_login_non_ecr():
235271
session_mock = Mock()
236272
sagemaker.local.image._ecr_login_if_needed(session_mock, 'ubuntu')

0 commit comments

Comments
 (0)