Skip to content
This repository was archived by the owner on Aug 26, 2020. It is now read-only.

Commit 4f7e7cc

Browse files
wiltonwumvsusp
authored andcommitted
change: download_and_extract local tar file (#194)
1 parent 5b44621 commit 4f7e7cc

File tree

5 files changed

+57
-16
lines changed

5 files changed

+57
-16
lines changed

src/sagemaker_containers/_files.py

+3
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ def download_and_extract(uri, path): # type: (str, str) -> None
132132
if os.path.exists(path):
133133
shutil.rmtree(path)
134134
shutil.move(uri, path)
135+
elif tarfile.is_tarfile(uri):
136+
with tarfile.open(name=uri, mode='r:gz') as t:
137+
t.extractall(path=path)
135138
else:
136139
shutil.copy2(uri, path)
137140

test/__init__.py

+19-13
Original file line numberDiff line numberDiff line change
@@ -150,22 +150,28 @@ def add_file(self, file): # type: (File) -> UserModule
150150
def url(self): # type: () -> str
151151
return os.path.join('s3://', self.bucket, self.key)
152152

153-
def upload(self): # type: () -> UserModule
154-
with _files.tmpdir() as tmpdir:
155-
tar_name = os.path.join(tmpdir, 'sourcedir.tar.gz')
156-
with tarfile.open(tar_name, mode='w:gz') as tar:
157-
for _file in self._files:
158-
name = os.path.join(tmpdir, _file.name)
159-
with open(name, 'w+') as f:
153+
def create_tar(self, dir_path=None):
154+
dir_path = dir_path or os.path.dirname(os.path.realpath(__file__))
155+
tar_name = os.path.join(dir_path, 'sourcedir.tar.gz')
156+
with tarfile.open(tar_name, mode='w:gz') as tar:
157+
for _file in self._files:
158+
name = os.path.join(dir_path, _file.name)
159+
with open(name, 'w+') as f:
160+
161+
if isinstance(_file.data, six.string_types):
162+
data = _file.data
163+
else:
164+
data = '\n'.join(_file.data)
160165

161-
if isinstance(_file.data, six.string_types):
162-
data = _file.data
163-
else:
164-
data = '\n'.join(_file.data)
166+
f.write(data)
167+
tar.add(name=name, arcname=_file.name)
168+
os.remove(name)
165169

166-
f.write(data)
167-
tar.add(name=name, arcname=_file.name)
170+
return tar_name
168171

172+
def upload(self): # type: () -> UserModule
173+
with _files.tmpdir() as tmpdir:
174+
tar_name = self.create_tar(dir_path=tmpdir)
169175
self._s3.Object(self.bucket, self.key).upload_file(tar_name)
170176
return self
171177

test/functional/test_download_and_import.py

+15
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from __future__ import absolute_import
1414

1515
import importlib
16+
import os
1617
import shlex
1718
import subprocess
1819
import textwrap
@@ -157,3 +158,17 @@ def test_import_module_with_s3_script_with_error(user_module_name):
157158

158159
with pytest.raises(errors.ImportModuleError):
159160
modules.import_module(user_module.url, user_module_name, cache=False)
161+
162+
163+
@pytest.mark.parametrize('user_module',
164+
[test.UserModule(USER_SCRIPT_WITH_REQUIREMENTS).add_file(SETUP_FILE),
165+
test.UserModule(USER_SCRIPT_WITH_REQUIREMENTS)])
166+
def test_import_module_with_local_tar_via_download_and_extract(user_module, user_module_name):
167+
user_module = user_module.add_file(REQUIREMENTS_FILE)
168+
tar_name = user_module.create_tar()
169+
170+
module = modules.import_module(tar_name, name=user_module_name, cache=False)
171+
172+
assert module.say() == REQUIREMENTS_TXT_ASSERT_STR
173+
174+
os.remove(tar_name)

test/unit/test_files.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import itertools
1414
import logging
1515
import os
16+
import tarfile
1617

1718
from mock import mock_open, patch
1819
import pytest
@@ -113,7 +114,7 @@ def test_write_failure_file():
113114
@patch('os.path.isdir', lambda x: True)
114115
@patch('shutil.rmtree')
115116
@patch('shutil.move')
116-
def test_download_and_and_extract_source_dir(move, rmtree, s3_download):
117+
def test_download_and_extract_source_dir(move, rmtree, s3_download):
117118
uri = _env.channel_path('code')
118119
_files.download_and_extract(uri, _env.code_dir)
119120
s3_download.assert_not_called()
@@ -125,9 +126,24 @@ def test_download_and_and_extract_source_dir(move, rmtree, s3_download):
125126
@patch('sagemaker_containers._files.s3_download')
126127
@patch('os.path.isdir', lambda x: False)
127128
@patch('shutil.copy2')
128-
def test_download_and_and_extract_file(copy, s3_download):
129-
uri = _env.channel_path('code')
129+
def test_download_and_extract_file(copy, s3_download):
130+
uri = __file__
130131
_files.download_and_extract(uri, _env.code_dir)
131132

132133
s3_download.assert_not_called()
133134
copy.assert_called_with(uri, _env.code_dir)
135+
136+
137+
@patch('sagemaker_containers._files.s3_download')
138+
@patch('os.path.isdir', lambda x: False)
139+
@patch('tarfile.TarFile.extractall')
140+
def test_download_and_extract_tar(extractall, s3_download):
141+
t = tarfile.open(name='test.tar.gz', mode='w:gz')
142+
t.close()
143+
uri = t.name
144+
_files.download_and_extract(uri, _env.code_dir)
145+
146+
s3_download.assert_not_called()
147+
extractall.assert_called_with(path=_env.code_dir)
148+
149+
os.remove(uri)

test/unit/test_intermediate_output.py

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def test_wrong_output():
3737

3838

3939
@patch('inotify_simple.INotify', MagicMock())
40+
@patch('boto3.client', MagicMock())
4041
def test_daemon_process():
4142
intemediate_sync = _intermediate_output.start_sync(S3_BUCKET, REGION)
4243
assert intemediate_sync.daemon is True

0 commit comments

Comments
 (0)