@@ -268,23 +268,57 @@ def test_download_file():
268
268
269
269
@patch ('tarfile.open' )
270
270
def test_create_tar_file_with_provided_path (open ):
271
- open .return_value = open
272
- open .__enter__ = Mock ()
273
- open .__exit__ = Mock (return_value = None )
271
+ files = mock_tarfile (open )
272
+
274
273
file_list = ['/tmp/a' , '/tmp/b' ]
274
+
275
275
path = sagemaker .utils .create_tar_file (file_list , target = '/my/custom/path.tar.gz' )
276
276
assert path == '/my/custom/path.tar.gz'
277
+ assert files == [['/tmp/a' , 'a' ], ['/tmp/b' , 'b' ]]
277
278
278
279
279
280
@patch ('tarfile.open' )
280
- @patch ('tempfile.mkstemp' , Mock (return_value = (None , '/auto/generated/path' )))
281
- def test_create_tar_file_with_auto_generated_path (open ):
281
+ def test_create_tar_file_with_directories (open ):
282
+ files = mock_tarfile (open )
283
+
284
+ path = sagemaker .utils .create_tar_file (dir_files = ['/tmp/a' , '/tmp/b' ],
285
+ target = '/my/custom/path.tar.gz' )
286
+ assert path == '/my/custom/path.tar.gz'
287
+ assert files == [['/tmp/a' , '/' ], ['/tmp/b' , '/' ]]
288
+
289
+
290
+ @patch ('tarfile.open' )
291
+ def test_create_tar_file_with_files_and_directories (open ):
292
+ files = mock_tarfile (open )
293
+
294
+ path = sagemaker .utils .create_tar_file (dir_files = ['/tmp/a' , '/tmp/b' ],
295
+ source_files = ['/tmp/c' , '/tmp/d' ],
296
+ target = '/my/custom/path.tar.gz' )
297
+ assert path == '/my/custom/path.tar.gz'
298
+ assert files == [['/tmp/c' , 'c' ], ['/tmp/d' , 'd' ], ['/tmp/a' , '/' ], ['/tmp/b' , '/' ]]
299
+
300
+
301
+ def mock_tarfile (open ):
282
302
open .return_value = open
303
+ files = []
304
+
305
+ def add_files (filename , arcname ):
306
+ files .append ([filename , arcname ])
307
+
283
308
open .__enter__ = Mock ()
309
+ open .__enter__ ().add = add_files
284
310
open .__exit__ = Mock (return_value = None )
285
- file_list = ['/tmp/a' , '/tmp/b' ]
286
- path = sagemaker .utils .create_tar_file (file_list )
311
+ return files
312
+
313
+
314
+ @patch ('tarfile.open' )
315
+ @patch ('tempfile.mkstemp' , Mock (return_value = (None , '/auto/generated/path' )))
316
+ def test_create_tar_file_with_auto_generated_path (open ):
317
+ files = mock_tarfile (open )
318
+
319
+ path = sagemaker .utils .create_tar_file (['/tmp/a' , '/tmp/b' ])
287
320
assert path == '/auto/generated/path'
321
+ assert files == [['/tmp/a' , 'a' ], ['/tmp/b' , 'b' ]]
288
322
289
323
290
324
def write_file (path , content ):
0 commit comments