Skip to content

Commit e684021

Browse files
committed
Add tests
1 parent df542d3 commit e684021

File tree

2 files changed

+142
-15
lines changed

2 files changed

+142
-15
lines changed

src/sagemaker/fw_utils.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -123,23 +123,19 @@ def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory, additi
123123
Returns:
124124
sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and script name.
125125
"""
126-
additional_files = additional_files or []
127126
key = '%s/sourcedir.tar.gz' % s3_key_prefix
128127

129-
basedir = directory if directory else os.path.dirname(script)
130-
131-
if basedir.lower().startswith("s3://"):
132-
s3_prefix = directory
128+
if directory and directory.lower().startswith("s3://"):
129+
return UploadedCode(s3_prefix=directory, script_name=os.path.basename(script))
133130
else:
134-
_upload_code(session, bucket, key, basedir, additional_files)
135-
s3_prefix = 's3://%s/%s' % (bucket, key)
131+
source_files = _list_root_files(script, directory, additional_files)
132+
_upload_code(session, bucket, key, source_files)
136133

137-
script_name = script if directory else os.path.basename(script)
138-
return UploadedCode(s3_prefix=s3_prefix, script_name=script_name)
134+
script_name = script if directory else os.path.basename(script)
135+
return UploadedCode(s3_prefix='s3://%s/%s' % (bucket, key), script_name=script_name)
139136

140137

141-
def _upload_code(session, bucket, key, dirname, additional_files):
142-
source_files = _list_files([dirname] + additional_files)
138+
def _upload_code(session, bucket, key, source_files):
143139
tar_file = sagemaker.utils.create_tar_file(source_files)
144140

145141
try:
@@ -148,13 +144,17 @@ def _upload_code(session, bucket, key, dirname, additional_files):
148144
os.remove(tar_file)
149145

150146

151-
def _list_files(files):
147+
def _list_root_files(script, directory, additional_files):
148+
additional_files = additional_files or []
149+
basedir = directory if directory else os.path.dirname(script)
150+
files = [basedir] + additional_files
151+
152152
for file in files:
153-
if os.path.isdir(file):
153+
if os.path.isfile(file):
154+
yield file
155+
else:
154156
for name in os.listdir(file):
155157
yield os.path.join(file, name)
156-
else:
157-
yield file
158158

159159

160160
def framework_name_from_image(image_name):

tests/unit/test_fw_utils.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import inspect
1616
import os
17+
import tarfile
1718

1819
import pytest
1920
from mock import Mock, patch
@@ -143,6 +144,132 @@ def test_tar_and_upload_dir_not_s3(sagemaker_session):
143144
script)
144145

145146

147+
def file_tree(tmpdir, files=None, folders=None):
148+
files = files or []
149+
folders = folders or []
150+
for file in files:
151+
tmpdir.join(file).ensure(file=True)
152+
153+
for folder in folders:
154+
tmpdir.join(folder).ensure(dir=True)
155+
156+
return str(tmpdir)
157+
158+
159+
def test_tar_and_upload_dir_no_directory(sagemaker_session, tmpdir):
160+
root = file_tree(tmpdir, ['train.py'])
161+
162+
with patch('os.remove'):
163+
result = tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', os.path.join(root, 'train.py'), None)
164+
165+
assert result == UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz',
166+
script_name='train.py')
167+
168+
assert {'/train.py'} == tarball_files(sagemaker_session, tmpdir)
169+
170+
171+
def test_tar_and_upload_dir_with_directory(sagemaker_session, tmpdir):
172+
file_tree(tmpdir, ['src-dir/train.py'])
173+
174+
root = os.path.join(str(tmpdir), 'src-dir')
175+
176+
with patch('os.remove'):
177+
result = tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', 'train.py', root)
178+
179+
assert result == UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz',
180+
script_name='train.py')
181+
182+
assert {'/train.py'} == tarball_files(sagemaker_session, tmpdir)
183+
184+
185+
def test_tar_and_upload_dir_with_subdirectory(sagemaker_session, tmpdir):
186+
file_tree(tmpdir, ['src-dir/sub/train.py'])
187+
188+
root = os.path.join(str(tmpdir), 'src-dir')
189+
190+
with patch('os.remove'):
191+
result = tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', 'train.py', root)
192+
193+
assert result == UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz',
194+
script_name='train.py')
195+
196+
assert {'/sub/train.py'} == tarball_files(sagemaker_session, tmpdir)
197+
198+
199+
def test_tar_and_upload_dir_with_directory_and_files(sagemaker_session, tmpdir):
200+
file_tree(tmpdir, ['src-dir/train.py', 'src-dir/laucher', 'src-dir/module/__init__.py'])
201+
root = os.path.join(str(tmpdir), 'src-dir')
202+
203+
with patch('os.remove'):
204+
result = tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', 'train.py', root)
205+
206+
assert result == UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz',
207+
script_name='train.py')
208+
209+
assert {'/laucher', '/module/__init__.py', '/train.py'} == tarball_files(sagemaker_session, tmpdir)
210+
211+
212+
def test_tar_and_upload_dir_with_directories_and_files(sagemaker_session, tmpdir):
213+
file_tree(tmpdir, ['src-dir/a/b', 'src-dir/a/b2', 'src-dir/x/y', 'src-dir/x/y2', 'src-dir/z'])
214+
root = os.path.join(str(tmpdir), 'src-dir')
215+
216+
with patch('os.remove'):
217+
result = tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', 'a/b', root)
218+
219+
assert result == UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz',
220+
script_name='a/b')
221+
222+
assert {'/a/b', '/a/b2', '/x/y', '/x/y2', '/z'} == tarball_files(sagemaker_session, tmpdir)
223+
224+
225+
def test_tar_and_upload_dir_with_many_folders(sagemaker_session, tmpdir):
226+
file_tree(tmpdir, ['src-dir/a/b', 'src-dir/a/b2', 'common/x/y', 'common/x/y2', 't/y/z'])
227+
root = os.path.join(str(tmpdir), 'src-dir')
228+
additional_files = [os.path.join(str(tmpdir), 'common'), os.path.join(str(tmpdir), 't', 'y', 'z')]
229+
230+
with patch('os.remove'):
231+
result = tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', 'model.py', root, additional_files)
232+
233+
assert result == UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz',
234+
script_name='model.py')
235+
236+
assert {'/a/b', '/a/b2', '/x/y', '/x/y2', '/z'} == tarball_files(sagemaker_session, tmpdir)
237+
238+
239+
def test_test_tar_and_upload_dir_with_subfolders(sagemaker_session, tmpdir):
240+
file_tree(tmpdir, ['a/b/c', 'a/b/c2'])
241+
root = file_tree(tmpdir, ['x/y/z', 'x/y/z2'])
242+
243+
with patch('os.remove'):
244+
result = tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', 'b/c',
245+
os.path.join(root, 'a'), [os.path.join(root, 'x')])
246+
247+
assert result == UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz',
248+
script_name='b/c')
249+
250+
assert {'/y/z2', '/b/c2', '/b/c', '/y/z'} == tarball_files(sagemaker_session, tmpdir)
251+
252+
253+
def tarball_files(sagemaker_session, tmpdir):
254+
startpath = str(tmpdir.ensure('/opt/ml/code/', dir=True))
255+
256+
tar_ball = sagemaker_session.resource('s3').Object().upload_file.call_args[0][0]
257+
try:
258+
259+
with tarfile.open(name=tar_ball, mode='r:gz') as t:
260+
t.extractall(path=startpath)
261+
262+
def walk():
263+
for root, dirs, files in os.walk(startpath):
264+
path = root.replace(startpath, '')
265+
for f in files:
266+
yield '{}/{}'.format(path, f)
267+
268+
return set(walk())
269+
finally:
270+
os.remove(tar_ball)
271+
272+
146273
def test_framework_name_from_image_mxnet():
147274
image_name = '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.1-gpu-py3'
148275
assert ('mxnet', 'py3', '1.1-gpu-py3') == framework_name_from_image(image_name)

0 commit comments

Comments
 (0)