Skip to content

Commit eeec58d

Browse files
committed
Handle PR comments
1 parent 3bbc3bb commit eeec58d

File tree

5 files changed

+70
-33
lines changed

5 files changed

+70
-33
lines changed

src/sagemaker/fw_utils.py

+10-11
Original file line numberDiff line numberDiff line change
@@ -182,9 +182,16 @@ def tar_and_upload_dir(session, bucket, s3_key_prefix, script,
182182
tmp = tempfile.mkdtemp()
183183

184184
try:
185-
source_files = _list_files_to_compress(script, directory) + dependencies
186-
tar_file = sagemaker.utils.create_tar_file(source_files,
187-
os.path.join(tmp, _TAR_SOURCE_FILENAME))
185+
if directory:
186+
source_files = dependencies
187+
dir_files = [directory]
188+
else:
189+
source_files = [script] + dependencies
190+
dir_files = []
191+
192+
tar_file = sagemaker.utils.create_tar_file(source_files=source_files,
193+
dir_files=dir_files,
194+
target=os.path.join(tmp, _TAR_SOURCE_FILENAME))
188195

189196
if kms_key:
190197
extra_args = {'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': kms_key}
@@ -198,14 +205,6 @@ def tar_and_upload_dir(session, bucket, s3_key_prefix, script,
198205
return UploadedCode(s3_prefix='s3://%s/%s' % (bucket, key), script_name=script_name)
199206

200207

201-
def _list_files_to_compress(script, directory):
202-
if directory is None:
203-
return [script]
204-
205-
basedir = directory if directory else os.path.dirname(script)
206-
return [os.path.join(basedir, name) for name in os.listdir(basedir)]
207-
208-
209208
def framework_name_from_image(image_name):
210209
"""Extract the framework and Python version from the image name.
211210

src/sagemaker/local/image.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -234,11 +234,10 @@ def retrieve_artifacts(self, compose_data, output_data_config, job_name):
234234
elif container_dir == '/opt/ml/output':
235235
sagemaker.local.utils.recursive_copy(host_dir, output_artifacts)
236236

237-
# Tar Artifacts -> model.tar.gz and output.tar.gz
238-
model_files = [os.path.join(model_artifacts, name) for name in os.listdir(model_artifacts)]
239-
output_files = [os.path.join(output_artifacts, name) for name in os.listdir(output_artifacts)]
240-
sagemaker.utils.create_tar_file(model_files, os.path.join(compressed_artifacts, 'model.tar.gz'))
241-
sagemaker.utils.create_tar_file(output_files, os.path.join(compressed_artifacts, 'output.tar.gz'))
237+
sagemaker.utils.create_tar_file(dir_files=[model_artifacts],
238+
target=os.path.join(compressed_artifacts, 'model.tar.gz'))
239+
sagemaker.utils.create_tar_file(dir_files=[output_artifacts],
240+
target=os.path.join(compressed_artifacts, 'output.tar.gz'))
242241

243242
if output_data_config['S3OutputPath'] == '':
244243
output_data = 'file://%s' % compressed_artifacts

src/sagemaker/utils.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -260,11 +260,14 @@ def download_folder(bucket_name, prefix, target, sagemaker_session):
260260
obj.download_file(file_path)
261261

262262

263-
def create_tar_file(source_files, target=None):
264-
"""Create a tar file containing all the source_files
263+
def create_tar_file(source_files=None, target=None, dir_files=None):
264+
"""Create a tar file containing all the source_files and the content of all dir_files
265265
266266
Args:
267267
source_files (List[str]): List of file paths that will be contained in the tar file
268+
target (str): target path of the tar file
269+
dir_files (List[str]): List of directories which will have their contents copied into
270+
the tar file
268271
269272
Returns:
270273
(str): path to created tar file
@@ -275,10 +278,17 @@ def create_tar_file(source_files, target=None):
275278
else:
276279
_, filename = tempfile.mkstemp()
277280

281+
dir_files = dir_files or []
282+
source_files = source_files or []
283+
278284
with tarfile.open(filename, mode='w:gz') as t:
279285
for sf in source_files:
280286
# Add each source file at the root of the tar archive, using only its base name
281287
t.add(sf, arcname=os.path.basename(sf))
288+
289+
for dir_file in dir_files:
290+
t.add(dir_file, arcname=os.path.sep)
291+
282292
return filename
283293

284294

@@ -323,13 +333,11 @@ def repack_model(inference_script, source_directory, model_uri, sagemaker_sessio
323333
new_model_name = 'model-%s.tar.gz' % sagemaker.utils.sagemaker_short_timestamp()
324334

325335
with _tmpdir() as tmp:
326-
327336
tmp_model_dir = os.path.join(tmp, 'model')
328337
os.mkdir(tmp_model_dir)
329338

330339
model_from_s3 = model_uri.startswith('s3://')
331340
if model_from_s3:
332-
333341
local_model_path = os.path.join(tmp, 'tar_file')
334342
download_file_from_url(model_uri, local_model_path, sagemaker_session)
335343

@@ -349,10 +357,7 @@ def repack_model(inference_script, source_directory, model_uri, sagemaker_sessio
349357

350358
shutil.copytree(dirname, code_dir)
351359

352-
files_to_compress = [os.path.join(tmp_model_dir, file)
353-
for file in os.listdir(tmp_model_dir)]
354-
355-
tar_file = sagemaker.utils.create_tar_file(files_to_compress, new_model_path)
360+
tar_file = sagemaker.utils.create_tar_file(dir_files=[tmp_model_dir], target=new_model_path)
356361

357362
if model_from_s3:
358363
url = parse.urlparse(model_uri)

tests/integ/test_tfs.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ def tfs_predictor(instance_type, sagemaker_session, tf_full_version):
5151

5252
def tar_dir(directory, tmpdir):
5353

54-
source_files = [os.path.join(directory, name) for name in os.listdir(directory)]
55-
return sagemaker.utils.create_tar_file(source_files, os.path.join(str(tmpdir), 'model.tar.gz'))
54+
return sagemaker.utils.create_tar_file(dir_files=[directory],
55+
target=os.path.join(str(tmpdir), 'model.tar.gz'))
5656

5757

5858
@pytest.fixture(scope='module')

tests/unit/test_utils.py

+41-7
Original file line numberDiff line numberDiff line change
@@ -268,23 +268,57 @@ def test_download_file():
268268

269269
@patch('tarfile.open')
270270
def test_create_tar_file_with_provided_path(open):
271-
open.return_value = open
272-
open.__enter__ = Mock()
273-
open.__exit__ = Mock(return_value=None)
271+
files = mock_tarfile(open)
272+
274273
file_list = ['/tmp/a', '/tmp/b']
274+
275275
path = sagemaker.utils.create_tar_file(file_list, target='/my/custom/path.tar.gz')
276276
assert path == '/my/custom/path.tar.gz'
277+
assert files == [['/tmp/a', 'a'], ['/tmp/b', 'b']]
277278

278279

279280
@patch('tarfile.open')
280-
@patch('tempfile.mkstemp', Mock(return_value=(None, '/auto/generated/path')))
281-
def test_create_tar_file_with_auto_generated_path(open):
281+
def test_create_tar_file_with_directories(open):
282+
files = mock_tarfile(open)
283+
284+
path = sagemaker.utils.create_tar_file(dir_files=['/tmp/a', '/tmp/b'],
285+
target='/my/custom/path.tar.gz')
286+
assert path == '/my/custom/path.tar.gz'
287+
assert files == [['/tmp/a', '/'], ['/tmp/b', '/']]
288+
289+
290+
@patch('tarfile.open')
291+
def test_create_tar_file_with_files_and_directories(open):
292+
files = mock_tarfile(open)
293+
294+
path = sagemaker.utils.create_tar_file(dir_files=['/tmp/a', '/tmp/b'],
295+
source_files=['/tmp/c', '/tmp/d'],
296+
target='/my/custom/path.tar.gz')
297+
assert path == '/my/custom/path.tar.gz'
298+
assert files == [['/tmp/c', 'c'], ['/tmp/d', 'd'], ['/tmp/a', '/'], ['/tmp/b', '/']]
299+
300+
301+
def mock_tarfile(open):
282302
open.return_value = open
303+
files = []
304+
305+
def add_files(filename, arcname):
306+
files.append([filename, arcname])
307+
283308
open.__enter__ = Mock()
309+
open.__enter__().add = add_files
284310
open.__exit__ = Mock(return_value=None)
285-
file_list = ['/tmp/a', '/tmp/b']
286-
path = sagemaker.utils.create_tar_file(file_list)
311+
return files
312+
313+
314+
@patch('tarfile.open')
315+
@patch('tempfile.mkstemp', Mock(return_value=(None, '/auto/generated/path')))
316+
def test_create_tar_file_with_auto_generated_path(open):
317+
files = mock_tarfile(open)
318+
319+
path = sagemaker.utils.create_tar_file(['/tmp/a', '/tmp/b'])
287320
assert path == '/auto/generated/path'
321+
assert files == [['/tmp/a', 'a'], ['/tmp/b', 'b']]
288322

289323

290324
def write_file(path, content):

0 commit comments

Comments
 (0)