Skip to content

Commit d8c055b

Browse files
authored
fix FileNotFoundError for entry_point without source_dir (#510)
* fix FileNotFoundError when user provides an entry_point but no source_dir * update docstring * bump version to 1.15.2 * fix integ tests failing due to breaking ExecuteUserScriptError change
1 parent d4f0f60 commit d8c055b

File tree

9 files changed

+85
-31
lines changed

9 files changed

+85
-31
lines changed

CHANGELOG.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22
CHANGELOG
33
=========
44

5+
1.15.2
6+
======
7+
8+
* bug-fix: Fix FileNotFoundError for entry_point without source_dir
9+
* doc-fix: Add missing feature 1.5.0 in change log
10+
* doc-fix: Add README for airflow
11+
512
1.15.1
613
======
714

doc/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def __getattr__(cls, name):
3232
'numpy', 'scipy', 'scipy.sparse']
3333
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
3434

35-
version = '1.15.1'
35+
version = '1.15.2'
3636
project = u'sagemaker'
3737

3838
# Add any Sphinx extension module names here, as strings. They can be extensions

src/sagemaker/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,4 @@
3737
from sagemaker.session import s3_input # noqa: F401
3838
from sagemaker.session import get_execution_role # noqa: F401
3939

40-
__version__ = '1.15.1'
40+
__version__ = '1.15.2'

src/sagemaker/fw_utils.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
from collections import namedtuple
16+
1517
import os
1618
import re
19+
import sagemaker.utils
1720
import shutil
1821
import tempfile
19-
from collections import namedtuple
2022
from six.moves.urllib.parse import urlparse
2123

22-
import sagemaker.utils
23-
2424
_TAR_SOURCE_FILENAME = 'source.tar.gz'
2525

2626
UploadedCode = namedtuple('UserCode', ['s3_prefix', 'script_name'])
@@ -112,46 +112,57 @@ def validate_source_dir(script, directory):
112112

113113

114114
def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory, dependencies=None):
115-
"""Pack and upload source files to S3 only if directory is empty or local.
115+
"""Package source files and upload a compress tar file to S3. The S3 location will be
116+
``s3://<bucket>/s3_key_prefix/sourcedir.tar.gz``.
117+
118+
If directory is an S3 URI, an UploadedCode object will be returned, but nothing will be
119+
uploaded to S3 (this allow reuse of code already in S3).
120+
121+
If directory is None, the script will be added to the archive at ``./<basename of script>``.
116122
117-
Note:
118-
If the directory points to S3 no action is taken.
123+
If directory is not None, the (recursive) contents of the directory will be added to
124+
the archive. directory is treated as the base path of the archive, and the script name is
125+
assumed to be a filename or relative path inside the directory.
119126
120127
Args:
121128
session (boto3.Session): Boto session used to access S3.
122129
bucket (str): S3 bucket to which the compressed file is uploaded.
123130
s3_key_prefix (str): Prefix for the S3 key.
124-
script (str): Script filename.
125-
directory (str or None): Directory containing the source file. If it starts with
126-
"s3://", no action is taken.
127-
dependencies (List[str]): A list of paths to directories (absolute or relative)
131+
script (str): Script filename or path.
132+
directory (str): Optional. Directory containing the source file. If it starts with "s3://",
133+
no action is taken.
134+
dependencies (List[str]): Optional. A list of paths to directories (absolute or relative)
128135
containing additional libraries that will be copied into
129136
/opt/ml/lib
130137
131138
Returns:
132-
sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and script name.
139+
sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and
140+
script name.
133141
"""
134-
dependencies = dependencies or []
135-
key = '%s/sourcedir.tar.gz' % s3_key_prefix
136-
137142
if directory and directory.lower().startswith('s3://'):
138143
return UploadedCode(s3_prefix=directory, script_name=os.path.basename(script))
139-
else:
140-
tmp = tempfile.mkdtemp()
141144

142-
try:
143-
source_files = _list_files_to_compress(script, directory) + dependencies
144-
tar_file = sagemaker.utils.create_tar_file(source_files, os.path.join(tmp, _TAR_SOURCE_FILENAME))
145+
script_name = script if directory else os.path.basename(script)
146+
dependencies = dependencies or []
147+
key = '%s/sourcedir.tar.gz' % s3_key_prefix
148+
tmp = tempfile.mkdtemp()
145149

146-
session.resource('s3').Object(bucket, key).upload_file(tar_file)
147-
finally:
148-
shutil.rmtree(tmp)
150+
try:
151+
source_files = _list_files_to_compress(script, directory) + dependencies
152+
tar_file = sagemaker.utils.create_tar_file(source_files,
153+
os.path.join(tmp, _TAR_SOURCE_FILENAME))
149154

150-
script_name = script if directory else os.path.basename(script)
151-
return UploadedCode(s3_prefix='s3://%s/%s' % (bucket, key), script_name=script_name)
155+
session.resource('s3').Object(bucket, key).upload_file(tar_file)
156+
finally:
157+
shutil.rmtree(tmp)
158+
159+
return UploadedCode(s3_prefix='s3://%s/%s' % (bucket, key), script_name=script_name)
152160

153161

154162
def _list_files_to_compress(script, directory):
163+
if directory is None:
164+
return [script]
165+
155166
basedir = directory if directory else os.path.dirname(script)
156167
return [os.path.join(basedir, name) for name in os.listdir(basedir)]
157168

tests/integ/test_chainer_train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def test_failed_training_job(sagemaker_session, chainer_full_version):
115115

116116
with pytest.raises(ValueError) as e:
117117
chainer.fit()
118-
assert 'This failure is expected' in str(e.value)
118+
assert 'ExecuteUserScriptError' in str(e.value)
119119

120120

121121
def _run_mnist_training_job(sagemaker_session, instance_type, instance_count,

tests/integ/test_mxnet_train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,4 +112,4 @@ def test_failed_training_job(sagemaker_session, mxnet_full_version):
112112

113113
with pytest.raises(ValueError) as e:
114114
mx.fit()
115-
assert 'This failure is expected' in str(e.value)
115+
assert 'ExecuteUserScriptError' in str(e.value)

tests/integ/test_pytorch_train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def test_failed_training_job(sagemaker_session, pytorch_full_version):
109109

110110
with pytest.raises(ValueError) as e:
111111
pytorch.fit()
112-
assert 'This failure is expected' in str(e.value)
112+
assert 'ExecuteUserScriptError' in str(e.value)
113113

114114

115115
def _upload_training_data(pytorch):

tests/integ/test_tf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,4 +162,4 @@ def test_failed_tf_training(sagemaker_session, tf_full_version):
162162

163163
with pytest.raises(ValueError) as e:
164164
estimator.fit()
165-
assert 'This failure is expected' in str(e.value)
165+
assert 'ExecuteUserScriptError' in str(e.value)

tests/unit/test_fw_utils.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import pytest
2020
from mock import Mock, patch
2121

22+
from contextlib import contextmanager
2223
from sagemaker import fw_utils
2324
from sagemaker.utils import name_from_image
2425

@@ -30,6 +31,14 @@
3031
TIMESTAMP = '2017-10-10-14-14-15'
3132

3233

34+
@contextmanager
35+
def cd(path):
36+
old_dir = os.getcwd()
37+
os.chdir(path)
38+
yield
39+
os.chdir(old_dir)
40+
41+
3342
@pytest.fixture()
3443
def sagemaker_session():
3544
boto_mock = Mock(name='boto_session', region_name=REGION)
@@ -132,7 +141,7 @@ def test_validate_source_dir_file_not_in_dir():
132141

133142

134143
def test_tar_and_upload_dir_not_s3(sagemaker_session):
135-
bucket = 'mybucker'
144+
bucket = 'mybucket'
136145
s3_key_prefix = 'something/source'
137146
script = os.path.basename(__file__)
138147
directory = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
@@ -166,6 +175,33 @@ def test_tar_and_upload_dir_no_directory(sagemaker_session, tmpdir):
166175
assert {'/train.py'} == list_source_dir_files(sagemaker_session, tmpdir)
167176

168177

178+
def test_tar_and_upload_dir_no_directory_only_entrypoint(sagemaker_session, tmpdir):
179+
source_dir = file_tree(tmpdir, ['train.py', 'not_me.py'])
180+
entrypoint = os.path.join(source_dir, 'train.py')
181+
182+
with patch('shutil.rmtree'):
183+
result = fw_utils.tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', entrypoint, None)
184+
185+
assert result == fw_utils.UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz',
186+
script_name='train.py')
187+
188+
assert {'/train.py'} == list_source_dir_files(sagemaker_session, tmpdir)
189+
190+
191+
def test_tar_and_upload_dir_no_directory_bare_filename(sagemaker_session, tmpdir):
192+
source_dir = file_tree(tmpdir, ['train.py'])
193+
entrypoint = 'train.py'
194+
195+
with patch('shutil.rmtree'):
196+
with cd(source_dir):
197+
result = fw_utils.tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', entrypoint, None)
198+
199+
assert result == fw_utils.UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz',
200+
script_name='train.py')
201+
202+
assert {'/train.py'} == list_source_dir_files(sagemaker_session, tmpdir)
203+
204+
169205
def test_tar_and_upload_dir_with_directory(sagemaker_session, tmpdir):
170206
file_tree(tmpdir, ['src-dir/train.py'])
171207
source_dir = os.path.join(str(tmpdir), 'src-dir')

0 commit comments

Comments
 (0)