|
14 | 14 |
|
15 | 15 | import inspect
|
16 | 16 | import os
|
| 17 | +import tarfile |
17 | 18 |
|
18 | 19 | import pytest
|
19 | 20 | from mock import Mock, patch
|
@@ -143,6 +144,132 @@ def test_tar_and_upload_dir_not_s3(sagemaker_session):
|
143 | 144 | script)
|
144 | 145 |
|
145 | 146 |
|
| 147 | +def file_tree(tmpdir, files=None, folders=None): |
| 148 | + files = files or [] |
| 149 | + folders = folders or [] |
| 150 | + for file in files: |
| 151 | + tmpdir.join(file).ensure(file=True) |
| 152 | + |
| 153 | + for folder in folders: |
| 154 | + tmpdir.join(folder).ensure(dir=True) |
| 155 | + |
| 156 | + return str(tmpdir) |
| 157 | + |
| 158 | + |
| 159 | +def test_tar_and_upload_dir_no_directory(sagemaker_session, tmpdir): |
| 160 | + root = file_tree(tmpdir, ['train.py']) |
| 161 | + |
| 162 | + with patch('os.remove'): |
| 163 | + result = tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', os.path.join(root, 'train.py'), None) |
| 164 | + |
| 165 | + assert result == UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz', |
| 166 | + script_name='train.py') |
| 167 | + |
| 168 | + assert {'/train.py'} == tarball_files(sagemaker_session, tmpdir) |
| 169 | + |
| 170 | + |
| 171 | +def test_tar_and_upload_dir_with_directory(sagemaker_session, tmpdir): |
| 172 | + file_tree(tmpdir, ['src-dir/train.py']) |
| 173 | + |
| 174 | + root = os.path.join(str(tmpdir), 'src-dir') |
| 175 | + |
| 176 | + with patch('os.remove'): |
| 177 | + result = tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', 'train.py', root) |
| 178 | + |
| 179 | + assert result == UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz', |
| 180 | + script_name='train.py') |
| 181 | + |
| 182 | + assert {'/train.py'} == tarball_files(sagemaker_session, tmpdir) |
| 183 | + |
| 184 | + |
| 185 | +def test_tar_and_upload_dir_with_subdirectory(sagemaker_session, tmpdir): |
| 186 | + file_tree(tmpdir, ['src-dir/sub/train.py']) |
| 187 | + |
| 188 | + root = os.path.join(str(tmpdir), 'src-dir') |
| 189 | + |
| 190 | + with patch('os.remove'): |
| 191 | + result = tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', 'train.py', root) |
| 192 | + |
| 193 | + assert result == UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz', |
| 194 | + script_name='train.py') |
| 195 | + |
| 196 | + assert {'/sub/train.py'} == tarball_files(sagemaker_session, tmpdir) |
| 197 | + |
| 198 | + |
| 199 | +def test_tar_and_upload_dir_with_directory_and_files(sagemaker_session, tmpdir): |
| 200 | + file_tree(tmpdir, ['src-dir/train.py', 'src-dir/laucher', 'src-dir/module/__init__.py']) |
| 201 | + root = os.path.join(str(tmpdir), 'src-dir') |
| 202 | + |
| 203 | + with patch('os.remove'): |
| 204 | + result = tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', 'train.py', root) |
| 205 | + |
| 206 | + assert result == UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz', |
| 207 | + script_name='train.py') |
| 208 | + |
| 209 | + assert {'/laucher', '/module/__init__.py', '/train.py'} == tarball_files(sagemaker_session, tmpdir) |
| 210 | + |
| 211 | + |
| 212 | +def test_tar_and_upload_dir_with_directories_and_files(sagemaker_session, tmpdir): |
| 213 | + file_tree(tmpdir, ['src-dir/a/b', 'src-dir/a/b2', 'src-dir/x/y', 'src-dir/x/y2', 'src-dir/z']) |
| 214 | + root = os.path.join(str(tmpdir), 'src-dir') |
| 215 | + |
| 216 | + with patch('os.remove'): |
| 217 | + result = tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', 'a/b', root) |
| 218 | + |
| 219 | + assert result == UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz', |
| 220 | + script_name='a/b') |
| 221 | + |
| 222 | + assert {'/a/b', '/a/b2', '/x/y', '/x/y2', '/z'} == tarball_files(sagemaker_session, tmpdir) |
| 223 | + |
| 224 | + |
| 225 | +def test_tar_and_upload_dir_with_many_folders(sagemaker_session, tmpdir): |
| 226 | + file_tree(tmpdir, ['src-dir/a/b', 'src-dir/a/b2', 'common/x/y', 'common/x/y2', 't/y/z']) |
| 227 | + root = os.path.join(str(tmpdir), 'src-dir') |
| 228 | + additional_files = [os.path.join(str(tmpdir), 'common'), os.path.join(str(tmpdir), 't', 'y', 'z')] |
| 229 | + |
| 230 | + with patch('os.remove'): |
| 231 | + result = tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', 'model.py', root, additional_files) |
| 232 | + |
| 233 | + assert result == UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz', |
| 234 | + script_name='model.py') |
| 235 | + |
| 236 | + assert {'/a/b', '/a/b2', '/x/y', '/x/y2', '/z'} == tarball_files(sagemaker_session, tmpdir) |
| 237 | + |
| 238 | + |
| 239 | +def test_test_tar_and_upload_dir_with_subfolders(sagemaker_session, tmpdir): |
| 240 | + file_tree(tmpdir, ['a/b/c', 'a/b/c2']) |
| 241 | + root = file_tree(tmpdir, ['x/y/z', 'x/y/z2']) |
| 242 | + |
| 243 | + with patch('os.remove'): |
| 244 | + result = tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', 'b/c', |
| 245 | + os.path.join(root, 'a'), [os.path.join(root, 'x')]) |
| 246 | + |
| 247 | + assert result == UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz', |
| 248 | + script_name='b/c') |
| 249 | + |
| 250 | + assert {'/y/z2', '/b/c2', '/b/c', '/y/z'} == tarball_files(sagemaker_session, tmpdir) |
| 251 | + |
| 252 | + |
| 253 | +def tarball_files(sagemaker_session, tmpdir): |
| 254 | + startpath = str(tmpdir.ensure('/opt/ml/code/', dir=True)) |
| 255 | + |
| 256 | + tar_ball = sagemaker_session.resource('s3').Object().upload_file.call_args[0][0] |
| 257 | + try: |
| 258 | + |
| 259 | + with tarfile.open(name=tar_ball, mode='r:gz') as t: |
| 260 | + t.extractall(path=startpath) |
| 261 | + |
| 262 | + def walk(): |
| 263 | + for root, dirs, files in os.walk(startpath): |
| 264 | + path = root.replace(startpath, '') |
| 265 | + for f in files: |
| 266 | + yield '{}/{}'.format(path, f) |
| 267 | + |
| 268 | + return set(walk()) |
| 269 | + finally: |
| 270 | + os.remove(tar_ball) |
| 271 | + |
| 272 | + |
146 | 273 | def test_framework_name_from_image_mxnet():
|
147 | 274 | image_name = '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.1-gpu-py3'
|
148 | 275 | assert ('mxnet', 'py3', '1.1-gpu-py3') == framework_name_from_image(image_name)
|
|
0 commit comments