|
19 | 19 | import pytest
|
20 | 20 | from mock import Mock, patch
|
21 | 21 |
|
| 22 | +from contextlib import contextmanager |
22 | 23 | from sagemaker import fw_utils
|
23 | 24 | from sagemaker.utils import name_from_image
|
24 | 25 |
|
|
30 | 31 | TIMESTAMP = '2017-10-10-14-14-15'
|
31 | 32 |
|
32 | 33 |
|
| 34 | +@contextmanager |
| 35 | +def cd(path): |
| 36 | + old_dir = os.getcwd() |
| 37 | + os.chdir(path) |
| 38 | + yield |
| 39 | + os.chdir(old_dir) |
| 40 | + |
| 41 | + |
33 | 42 | @pytest.fixture()
|
34 | 43 | def sagemaker_session():
|
35 | 44 | boto_mock = Mock(name='boto_session', region_name=REGION)
|
@@ -132,7 +141,7 @@ def test_validate_source_dir_file_not_in_dir():
|
132 | 141 |
|
133 | 142 |
|
134 | 143 | def test_tar_and_upload_dir_not_s3(sagemaker_session):
|
135 |
| - bucket = 'mybucker' |
| 144 | + bucket = 'mybucket' |
136 | 145 | s3_key_prefix = 'something/source'
|
137 | 146 | script = os.path.basename(__file__)
|
138 | 147 | directory = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
|
@@ -166,6 +175,33 @@ def test_tar_and_upload_dir_no_directory(sagemaker_session, tmpdir):
|
166 | 175 | assert {'/train.py'} == list_source_dir_files(sagemaker_session, tmpdir)
|
167 | 176 |
|
168 | 177 |
|
| 178 | +def test_tar_and_upload_dir_no_directory_only_entrypoint(sagemaker_session, tmpdir): |
| 179 | + source_dir = file_tree(tmpdir, ['train.py', 'not_me.py']) |
| 180 | + entrypoint = os.path.join(source_dir, 'train.py') |
| 181 | + |
| 182 | + with patch('shutil.rmtree'): |
| 183 | + result = fw_utils.tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', entrypoint, None) |
| 184 | + |
| 185 | + assert result == fw_utils.UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz', |
| 186 | + script_name='train.py') |
| 187 | + |
| 188 | + assert {'/train.py'} == list_source_dir_files(sagemaker_session, tmpdir) |
| 189 | + |
| 190 | + |
| 191 | +def test_tar_and_upload_dir_no_directory_bare_filename(sagemaker_session, tmpdir): |
| 192 | + source_dir = file_tree(tmpdir, ['train.py']) |
| 193 | + entrypoint = 'train.py' |
| 194 | + |
| 195 | + with patch('shutil.rmtree'): |
| 196 | + with cd(source_dir): |
| 197 | + result = fw_utils.tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', entrypoint, None) |
| 198 | + |
| 199 | + assert result == fw_utils.UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz', |
| 200 | + script_name='train.py') |
| 201 | + |
| 202 | + assert {'/train.py'} == list_source_dir_files(sagemaker_session, tmpdir) |
| 203 | + |
| 204 | + |
169 | 205 | def test_tar_and_upload_dir_with_directory(sagemaker_session, tmpdir):
|
170 | 206 | file_tree(tmpdir, ['src-dir/train.py'])
|
171 | 207 | source_dir = os.path.join(str(tmpdir), 'src-dir')
|
|
0 commit comments