Skip to content

Commit 493d11d

Browse files
mvsusppengk19
authored and committed
fix: repack model function works without source directory (aws#804)
1 parent f8cba72 commit 493d11d

File tree

2 files changed

+52
-4
lines changed

2 files changed

+52
-4
lines changed

src/sagemaker/utils.py

+5 -3
Original file line numberDiff line numberDiff line change
@@ -340,9 +340,11 @@ def repack_model(inference_script, source_directory, model_uri, sagemaker_sessio
340340
if os.path.exists(code_dir):
341341
shutil.rmtree(code_dir, ignore_errors=True)
342342

343-
dirname = source_directory if source_directory else os.path.dirname(inference_script)
344-
345-
shutil.copytree(dirname, code_dir)
343+
if source_directory:
344+
shutil.copytree(source_directory, code_dir)
345+
else:
346+
os.mkdir(code_dir)
347+
shutil.copy2(inference_script, code_dir)
346348

347349
with tarfile.open(new_model_path, mode='w:gz') as t:
348350
t.add(tmp_model_dir, arcname=os.path.sep)

tests/unit/test_utils.py

+47 -1
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,9 @@ def test_repack_model_without_source_dir(tmpdir):
317317
script_path = os.path.join(source_dir, 'inference.py')
318318
write_file(script_path, 'inference script')
319319

320+
script_path = os.path.join(source_dir, 'this-file-should-not-be-included.py')
321+
write_file(script_path, 'This file should not be included')
322+
320323
contents = [model_path]
321324

322325
sagemaker_session = MagicMock()
@@ -334,6 +337,44 @@ def test_repack_model_without_source_dir(tmpdir):
334337
assert re.match(r'^s3://fake/model-\d+-\d+.tar.gz$', new_model_uri)
335338

336339

340+
def test_repack_model_with_entry_point_without_path_without_source_dir(tmpdir):
341+
342+
tmp = str(tmpdir)
343+
344+
model_path = os.path.join(tmp, 'model')
345+
write_file(model_path, 'model data')
346+
347+
source_dir = os.path.join(tmp, 'source-dir')
348+
os.mkdir(source_dir)
349+
script_path = os.path.join(source_dir, 'inference.py')
350+
write_file(script_path, 'inference script')
351+
352+
script_path = os.path.join(source_dir, 'this-file-should-not-be-included.py')
353+
write_file(script_path, 'This file should not be included')
354+
355+
contents = [model_path]
356+
357+
sagemaker_session = MagicMock()
358+
mock_s3_model_tar(contents, sagemaker_session, tmp)
359+
fake_upload_path = mock_s3_upload(sagemaker_session, tmp)
360+
361+
model_uri = 's3://fake/location'
362+
363+
cwd = os.getcwd()
364+
try:
365+
os.chdir(source_dir)
366+
367+
new_model_uri = sagemaker.utils.repack_model('inference.py',
368+
None,
369+
model_uri,
370+
sagemaker_session)
371+
finally:
372+
os.chdir(cwd)
373+
374+
assert list_tar_files(fake_upload_path, tmpdir) == {'/code/inference.py', '/model'}
375+
assert re.match(r'^s3://fake/model-\d+-\d+.tar.gz$', new_model_uri)
376+
377+
337378
def test_repack_model_from_s3_saved_model_to_s3(tmpdir):
338379

339380
tmp = str(tmpdir)
@@ -346,6 +387,9 @@ def test_repack_model_from_s3_saved_model_to_s3(tmpdir):
346387
script_path = os.path.join(source_dir, 'inference.py')
347388
write_file(script_path, 'inference script')
348389

390+
script_path = os.path.join(source_dir, 'this-file-should-be-included.py')
391+
write_file(script_path, 'This file should be included')
392+
349393
contents = [model_path]
350394

351395
sagemaker_session = MagicMock()
@@ -359,7 +403,9 @@ def test_repack_model_from_s3_saved_model_to_s3(tmpdir):
359403
model_uri,
360404
sagemaker_session)
361405

362-
assert list_tar_files(fake_upload_path, tmpdir) == {'/code/inference.py', '/model'}
406+
assert list_tar_files(fake_upload_path, tmpdir) == {'/code/this-file-should-be-included.py',
407+
'/code/inference.py',
408+
'/model'}
363409
assert re.match(r'^s3://fake/model-\d+-\d+.tar.gz$', new_model_uri)
364410

365411

0 commit comments

Comments
 (0)