From d7eccc6ab250683ea7b35b1dafad74d2a9ec1adb Mon Sep 17 00:00:00 2001 From: Ignacio Quintero Date: Thu, 12 Apr 2018 14:58:17 -0700 Subject: [PATCH 1/2] Fix local mode not using the right s3 bucket. Local Mode should honor the inputs instead of wrongly assuming that everyone is using the default bucket. --- src/sagemaker/local/image.py | 4 ++-- tests/unit/test_image.py | 24 +++++++++++++++++------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index 423ea154df..35afb8a47d 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -87,7 +87,6 @@ def train(self, input_data_config, hyperparameters): os.mkdir(os.path.join(self.container_root, 'output')) data_dir = self._create_tmp_folder() - bucket_name = self.sagemaker_session.default_bucket() volumes = [] # Set up the channels for the containers. For local data we will @@ -102,7 +101,8 @@ def train(self, input_data_config, hyperparameters): channel_dir = os.path.join(data_dir, channel_name) os.mkdir(channel_dir) - if uri.lower().startswith("s3://"): + if parsed_uri.scheme == 's3': + bucket_name = parsed_uri.netloc self._download_folder(bucket_name, key, channel_dir) else: volumes.append(_Volume(uri, channel=channel_name)) diff --git a/tests/unit/test_image.py b/tests/unit/test_image.py index 1fd82f0ab6..3b83083f1f 100644 --- a/tests/unit/test_image.py +++ b/tests/unit/test_image.py @@ -16,7 +16,7 @@ import pytest import yaml -from mock import patch, Mock +from mock import call, patch, Mock, ANY import sagemaker from sagemaker.local.image import _SageMakerContainer @@ -40,7 +40,7 @@ 'S3DataSource': { 'S3DataDistributionType': 'FullyReplicated', 'S3DataType': 'S3Prefix', - 'S3Uri': 's3://foo/bar' + 'S3Uri': 's3://my-own-bucket/prefix' } } } @@ -54,12 +54,12 @@ def sagemaker_session(): boto_mock.client('sts').get_caller_identity.return_value = {'Account': '123'} boto_mock.resource('s3').Bucket(BUCKET_NAME).objects.filter.return_value = [] - ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock()) + sms = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock()) - ims.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) - ims.expand_role = Mock(return_value=EXPANDED_ROLE) + sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) + sms.expand_role = Mock(return_value=EXPANDED_ROLE) - return ims + return sms @patch('sagemaker.local.local_session.LocalSession') @@ -181,7 +181,8 @@ def test_check_output(): @patch('sagemaker.local.local_session.LocalSession') @patch('sagemaker.local.image._execute_and_stream_output') @patch('sagemaker.local.image._SageMakerContainer._cleanup') -def test_train(LocalSession, _execute_and_stream_output, _cleanup, tmpdir, sagemaker_session): +@patch('sagemaker.local.image._SageMakerContainer._download_folder') +def test_train(_download_folder, _cleanup, _execute_and_stream_output, LocalSession, tmpdir, sagemaker_session): with patch('sagemaker.local.image._SageMakerContainer._create_tmp_folder', side_effect=[str(tmpdir.mkdir('container-root')), str(tmpdir.mkdir('data'))]): @@ -191,6 +192,15 @@ def test_train(LocalSession, _execute_and_stream_output, _cleanup, tmpdir, sagem sagemaker_container = _SageMakerContainer('local', instance_count, image, sagemaker_session=sagemaker_session) sagemaker_container.train(INPUT_DATA_CONFIG, HYPERPARAMETERS) + download_folder_calls = [] + for channel in INPUT_DATA_CONFIG: + s3_uri = channel['DataSource']['S3DataSource']['S3Uri'] + if 's3://' in s3_uri: + bucket, prefix = s3_uri.replace('s3://', '').split('/') + download_folder_calls.append(call(bucket, prefix, ANY)) + _download_folder.assert_called() + _download_folder.assert_has_calls(download_folder_calls) + docker_compose_file = os.path.join(sagemaker_container.container_root, 'docker-compose.yaml') call_args = _execute_and_stream_output.call_args[0][0] From 57b82895f14010cbc0806a2855de76bdf736b2e2 Mon Sep 17 00:00:00 2001 From: Ignacio Quintero Date: Fri, 13 Apr 2018 10:49:56 -0700 Subject: [PATCH 2/2] Add a unit test for _download_folder() in image.py Also added the changelog entry --- CHANGELOG.rst | 4 ++++ tests/unit/test_image.py | 44 ++++++++++++++++++++++++++++++++-------- 2 files changed, 39 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 884370e68d..d5e2ae4aba 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,6 +2,10 @@ CHANGELOG ========= +1.2.3-dev +========= +* bug-fix: Fix local mode not using the right s3 bucket + 1.2.2 ===== diff --git a/tests/unit/test_image.py b/tests/unit/test_image.py index 3b83083f1f..e23ca28036 100644 --- a/tests/unit/test_image.py +++ b/tests/unit/test_image.py @@ -16,7 +16,7 @@ import pytest import yaml -from mock import call, patch, Mock, ANY +from mock import call, patch, Mock import sagemaker from sagemaker.local.image import _SageMakerContainer @@ -184,21 +184,17 @@ def test_check_output(): @patch('sagemaker.local.image._SageMakerContainer._download_folder') def test_train(_download_folder, _cleanup, _execute_and_stream_output, LocalSession, tmpdir, sagemaker_session): + directories = [str(tmpdir.mkdir('container-root')), str(tmpdir.mkdir('data'))] with patch('sagemaker.local.image._SageMakerContainer._create_tmp_folder', - side_effect=[str(tmpdir.mkdir('container-root')), str(tmpdir.mkdir('data'))]): + side_effect=directories): instance_count = 2 image = 'my-image' sagemaker_container = _SageMakerContainer('local', instance_count, image, sagemaker_session=sagemaker_session) sagemaker_container.train(INPUT_DATA_CONFIG, HYPERPARAMETERS) - download_folder_calls = [] - for channel in INPUT_DATA_CONFIG: - s3_uri = channel['DataSource']['S3DataSource']['S3Uri'] - if 's3://' in s3_uri: - bucket, prefix = s3_uri.replace('s3://', '').split('/') - download_folder_calls.append(call(bucket, prefix, ANY)) - _download_folder.assert_called() + channel_dir = os.path.join(directories[1], 'b') + download_folder_calls = [call('my-own-bucket', 'prefix', channel_dir)] _download_folder.assert_has_calls(download_folder_calls) docker_compose_file = os.path.join(sagemaker_container.container_root, 'docker-compose.yaml') @@ -241,6 +237,36 @@ def test_serve(up, copy, copytree, tmpdir, sagemaker_session): assert config['services'][h]['command'] == 'serve' +@patch('os.makedirs') +def test_download_folder(makedirs): + boto_mock = Mock(name='boto_session') + boto_mock.client('sts').get_caller_identity.return_value = {'Account': '123'} + + session = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock()) + + train_data = Mock() + validation_data = Mock() + + train_data.bucket_name.return_value = BUCKET_NAME + train_data.key = '/prefix/train/train_data.csv' + validation_data.bucket_name.return_value = BUCKET_NAME + validation_data.key = '/prefix/train/validation_data.csv' + + s3_files = [train_data, validation_data] + boto_mock.resource('s3').Bucket(BUCKET_NAME).objects.filter.return_value = s3_files + + obj_mock = Mock() + boto_mock.resource('s3').Object.return_value = obj_mock + + sagemaker_container = _SageMakerContainer('local', 2, 'my-image', sagemaker_session=session) + sagemaker_container._download_folder(BUCKET_NAME, '/prefix', '/tmp') + + obj_mock.download_file.assert_called() + calls = [call(os.path.join('/tmp', 'train/train_data.csv')), + call(os.path.join('/tmp', 'train/validation_data.csv'))] + obj_mock.download_file.assert_has_calls(calls) + + def test_ecr_login_non_ecr(): session_mock = Mock() sagemaker.local.image._ecr_login_if_needed(session_mock, 'ubuntu')