16
16
17
17
import pytest
18
18
import yaml
19
- from mock import patch , Mock
19
+ from mock import call , patch , Mock
20
20
21
21
import sagemaker
22
22
from sagemaker .local .image import _SageMakerContainer
40
40
'S3DataSource' : {
41
41
'S3DataDistributionType' : 'FullyReplicated' ,
42
42
'S3DataType' : 'S3Prefix' ,
43
- 'S3Uri' : 's3://foo/bar '
43
+ 'S3Uri' : 's3://my-own-bucket/prefix '
44
44
}
45
45
}
46
46
}
@@ -54,12 +54,12 @@ def sagemaker_session():
54
54
boto_mock .client ('sts' ).get_caller_identity .return_value = {'Account' : '123' }
55
55
boto_mock .resource ('s3' ).Bucket (BUCKET_NAME ).objects .filter .return_value = []
56
56
57
- ims = sagemaker .Session (boto_session = boto_mock , sagemaker_client = Mock ())
57
+ sms = sagemaker .Session (boto_session = boto_mock , sagemaker_client = Mock ())
58
58
59
- ims .default_bucket = Mock (name = 'default_bucket' , return_value = BUCKET_NAME )
60
- ims .expand_role = Mock (return_value = EXPANDED_ROLE )
59
+ sms .default_bucket = Mock (name = 'default_bucket' , return_value = BUCKET_NAME )
60
+ sms .expand_role = Mock (return_value = EXPANDED_ROLE )
61
61
62
- return ims
62
+ return sms
63
63
64
64
65
65
@patch ('sagemaker.local.local_session.LocalSession' )
@@ -181,16 +181,22 @@ def test_check_output():
181
181
@patch ('sagemaker.local.local_session.LocalSession' )
182
182
@patch ('sagemaker.local.image._execute_and_stream_output' )
183
183
@patch ('sagemaker.local.image._SageMakerContainer._cleanup' )
184
- def test_train (LocalSession , _execute_and_stream_output , _cleanup , tmpdir , sagemaker_session ):
184
+ @patch ('sagemaker.local.image._SageMakerContainer._download_folder' )
185
+ def test_train (_download_folder , _cleanup , _execute_and_stream_output , LocalSession , tmpdir , sagemaker_session ):
185
186
187
+ directories = [str (tmpdir .mkdir ('container-root' )), str (tmpdir .mkdir ('data' ))]
186
188
with patch ('sagemaker.local.image._SageMakerContainer._create_tmp_folder' ,
187
- side_effect = [ str ( tmpdir . mkdir ( 'container-root' )), str ( tmpdir . mkdir ( 'data' ))] ):
189
+ side_effect = directories ):
188
190
189
191
instance_count = 2
190
192
image = 'my-image'
191
193
sagemaker_container = _SageMakerContainer ('local' , instance_count , image , sagemaker_session = sagemaker_session )
192
194
sagemaker_container .train (INPUT_DATA_CONFIG , HYPERPARAMETERS )
193
195
196
+ channel_dir = os .path .join (directories [1 ], 'b' )
197
+ download_folder_calls = [call ('my-own-bucket' , 'prefix' , channel_dir )]
198
+ _download_folder .assert_has_calls (download_folder_calls )
199
+
194
200
docker_compose_file = os .path .join (sagemaker_container .container_root , 'docker-compose.yaml' )
195
201
196
202
call_args = _execute_and_stream_output .call_args [0 ][0 ]
@@ -231,6 +237,36 @@ def test_serve(up, copy, copytree, tmpdir, sagemaker_session):
231
237
assert config ['services' ][h ]['command' ] == 'serve'
232
238
233
239
240
+ @patch ('os.makedirs' )
241
+ def test_download_folder (makedirs ):
242
+ boto_mock = Mock (name = 'boto_session' )
243
+ boto_mock .client ('sts' ).get_caller_identity .return_value = {'Account' : '123' }
244
+
245
+ session = sagemaker .Session (boto_session = boto_mock , sagemaker_client = Mock ())
246
+
247
+ train_data = Mock ()
248
+ validation_data = Mock ()
249
+
250
+ train_data .bucket_name .return_value = BUCKET_NAME
251
+ train_data .key = '/prefix/train/train_data.csv'
252
+ validation_data .bucket_name .return_value = BUCKET_NAME
253
+ validation_data .key = '/prefix/train/validation_data.csv'
254
+
255
+ s3_files = [train_data , validation_data ]
256
+ boto_mock .resource ('s3' ).Bucket (BUCKET_NAME ).objects .filter .return_value = s3_files
257
+
258
+ obj_mock = Mock ()
259
+ boto_mock .resource ('s3' ).Object .return_value = obj_mock
260
+
261
+ sagemaker_container = _SageMakerContainer ('local' , 2 , 'my-image' , sagemaker_session = session )
262
+ sagemaker_container ._download_folder (BUCKET_NAME , '/prefix' , '/tmp' )
263
+
264
+ obj_mock .download_file .assert_called ()
265
+ calls = [call (os .path .join ('/tmp' , 'train/train_data.csv' )),
266
+ call (os .path .join ('/tmp' , 'train/validation_data.csv' ))]
267
+ obj_mock .download_file .assert_has_calls (calls )
268
+
269
+
234
270
def test_ecr_login_non_ecr ():
235
271
session_mock = Mock ()
236
272
sagemaker .local .image ._ecr_login_if_needed (session_mock , 'ubuntu' )
0 commit comments