Skip to content

Commit 0c89d56

Browse files
author
Jonathan Esterhazy
committed
fix FileNotFoundError when user provides an entry_point but no source_dir
1 parent 0071ff8 commit 0c89d56

File tree

2 files changed

+52
-14
lines changed

2 files changed

+52
-14
lines changed

src/sagemaker/fw_utils.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -131,27 +131,29 @@ def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory, depend
131131
Returns:
132132
sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and script name.
133133
"""
134-
dependencies = dependencies or []
135-
key = '%s/sourcedir.tar.gz' % s3_key_prefix
136-
137134
if directory and directory.lower().startswith('s3://'):
138135
return UploadedCode(s3_prefix=directory, script_name=os.path.basename(script))
139-
else:
140-
tmp = tempfile.mkdtemp()
141136

142-
try:
143-
source_files = _list_files_to_compress(script, directory) + dependencies
144-
tar_file = sagemaker.utils.create_tar_file(source_files, os.path.join(tmp, _TAR_SOURCE_FILENAME))
137+
script_name = script if directory else os.path.basename(script)
138+
dependencies = dependencies or []
139+
key = '%s/sourcedir.tar.gz' % s3_key_prefix
140+
tmp = tempfile.mkdtemp()
145141

146-
session.resource('s3').Object(bucket, key).upload_file(tar_file)
147-
finally:
148-
shutil.rmtree(tmp)
142+
try:
143+
source_files = _list_files_to_compress(script, directory) + dependencies
144+
tar_file = sagemaker.utils.create_tar_file(source_files, os.path.join(tmp, _TAR_SOURCE_FILENAME))
149145

150-
script_name = script if directory else os.path.basename(script)
151-
return UploadedCode(s3_prefix='s3://%s/%s' % (bucket, key), script_name=script_name)
146+
session.resource('s3').Object(bucket, key).upload_file(tar_file)
147+
finally:
148+
shutil.rmtree(tmp)
149+
150+
return UploadedCode(s3_prefix='s3://%s/%s' % (bucket, key), script_name=script_name)
152151

153152

154153
def _list_files_to_compress(script, directory):
154+
if directory is None:
155+
return [script]
156+
155157
basedir = directory if directory else os.path.dirname(script)
156158
return [os.path.join(basedir, name) for name in os.listdir(basedir)]
157159

tests/unit/test_fw_utils.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import pytest
2020
from mock import Mock, patch
2121

22+
from contextlib import contextmanager
2223
from sagemaker import fw_utils
2324
from sagemaker.utils import name_from_image
2425

@@ -30,6 +31,14 @@
3031
TIMESTAMP = '2017-10-10-14-14-15'
3132

3233

34+
@contextmanager
35+
def cd(path):
36+
old_dir = os.getcwd()
37+
os.chdir(path)
38+
yield
39+
os.chdir(old_dir)
40+
41+
3342
@pytest.fixture()
3443
def sagemaker_session():
3544
boto_mock = Mock(name='boto_session', region_name=REGION)
@@ -132,7 +141,7 @@ def test_validate_source_dir_file_not_in_dir():
132141

133142

134143
def test_tar_and_upload_dir_not_s3(sagemaker_session):
135-
bucket = 'mybucker'
144+
bucket = 'mybucket'
136145
s3_key_prefix = 'something/source'
137146
script = os.path.basename(__file__)
138147
directory = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
@@ -166,6 +175,33 @@ def test_tar_and_upload_dir_no_directory(sagemaker_session, tmpdir):
166175
assert {'/train.py'} == list_source_dir_files(sagemaker_session, tmpdir)
167176

168177

178+
def test_tar_and_upload_dir_no_directory_only_entrypoint(sagemaker_session, tmpdir):
    """With no source directory, only the entry point script is uploaded.

    A sibling file ('not_me.py') sits next to the entry point and must not
    appear in the listed source files.
    """
    tree_root = file_tree(tmpdir, ['train.py', 'not_me.py'])
    script_path = os.path.join(tree_root, 'train.py')

    expected = fw_utils.UploadedCode(
        s3_prefix='s3://bucket/prefix/sourcedir.tar.gz',
        script_name='train.py',
    )

    # shutil.rmtree is patched out, presumably so the temporary tarball is
    # still available for the listing check below.
    with patch('shutil.rmtree'):
        uploaded = fw_utils.tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', script_path, None)

    assert uploaded == expected

    assert {'/train.py'} == list_source_dir_files(sagemaker_session, tmpdir)
189+
190+
191+
def test_tar_and_upload_dir_no_directory_bare_filename(sagemaker_session, tmpdir):
    """A bare entry point filename (no directory component) uploads cleanly.

    The call runs from inside the directory holding the script, so the
    relative name 'train.py' resolves against the current working directory.
    """
    tree_root = file_tree(tmpdir, ['train.py'])
    script_path = 'train.py'

    expected = fw_utils.UploadedCode(
        s3_prefix='s3://bucket/prefix/sourcedir.tar.gz',
        script_name='train.py',
    )

    # shutil.rmtree is patched out, presumably so the temporary tarball is
    # still available for the listing check below.
    with patch('shutil.rmtree'), cd(tree_root):
        uploaded = fw_utils.tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', script_path, None)

    assert uploaded == expected

    assert {'/train.py'} == list_source_dir_files(sagemaker_session, tmpdir)
203+
204+
169205
def test_tar_and_upload_dir_with_directory(sagemaker_session, tmpdir):
170206
file_tree(tmpdir, ['src-dir/train.py'])
171207
source_dir = os.path.join(str(tmpdir), 'src-dir')

0 commit comments

Comments
 (0)