diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 5f35e4f259..889a589905 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -340,9 +340,11 @@ def repack_model(inference_script, source_directory, model_uri, sagemaker_sessio if os.path.exists(code_dir): shutil.rmtree(code_dir, ignore_errors=True) - dirname = source_directory if source_directory else os.path.dirname(inference_script) - - shutil.copytree(dirname, code_dir) + if source_directory: + shutil.copytree(source_directory, code_dir) + else: + os.mkdir(code_dir) + shutil.copy2(inference_script, code_dir) with tarfile.open(new_model_path, mode='w:gz') as t: t.add(tmp_model_dir, arcname=os.path.sep) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index efd0ad499a..94511939e3 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -317,6 +317,9 @@ def test_repack_model_without_source_dir(tmpdir): script_path = os.path.join(source_dir, 'inference.py') write_file(script_path, 'inference script') + script_path = os.path.join(source_dir, 'this-file-should-not-be-included.py') + write_file(script_path, 'This file should not be included') + contents = [model_path] sagemaker_session = MagicMock() @@ -334,6 +337,44 @@ def test_repack_model_without_source_dir(tmpdir): assert re.match(r'^s3://fake/model-\d+-\d+.tar.gz$', new_model_uri) +def test_repack_model_with_entry_point_without_path_without_source_dir(tmpdir): + + tmp = str(tmpdir) + + model_path = os.path.join(tmp, 'model') + write_file(model_path, 'model data') + + source_dir = os.path.join(tmp, 'source-dir') + os.mkdir(source_dir) + script_path = os.path.join(source_dir, 'inference.py') + write_file(script_path, 'inference script') + + script_path = os.path.join(source_dir, 'this-file-should-not-be-included.py') + write_file(script_path, 'This file should not be included') + + contents = [model_path] + + sagemaker_session = MagicMock() + mock_s3_model_tar(contents, sagemaker_session, tmp) + fake_upload_path = mock_s3_upload(sagemaker_session, tmp) + + model_uri = 's3://fake/location' + + cwd = os.getcwd() + try: + os.chdir(source_dir) + + new_model_uri = sagemaker.utils.repack_model('inference.py', + None, + model_uri, + sagemaker_session) + finally: + os.chdir(cwd) + + assert list_tar_files(fake_upload_path, tmpdir) == {'/code/inference.py', '/model'} + assert re.match(r'^s3://fake/model-\d+-\d+.tar.gz$', new_model_uri) + + def test_repack_model_from_s3_saved_model_to_s3(tmpdir): tmp = str(tmpdir) @@ -346,6 +387,9 @@ def test_repack_model_from_s3_saved_model_to_s3(tmpdir): script_path = os.path.join(source_dir, 'inference.py') write_file(script_path, 'inference script') + script_path = os.path.join(source_dir, 'this-file-should-be-included.py') + write_file(script_path, 'This file should be included') + contents = [model_path] sagemaker_session = MagicMock() @@ -359,7 +403,9 @@ def test_repack_model_from_s3_saved_model_to_s3(tmpdir): model_uri, sagemaker_session) - assert list_tar_files(fake_upload_path, tmpdir) == {'/code/inference.py', '/model'} + assert list_tar_files(fake_upload_path, tmpdir) == {'/code/this-file-should-be-included.py', + '/code/inference.py', + '/model'} assert re.match(r'^s3://fake/model-\d+-\d+.tar.gz$', new_model_uri)