Skip to content

Commit fde4d9f

Browse files
committed
Handle PR comments
1 parent c354878 commit fde4d9f

File tree

2 files changed

+13
-16
lines changed

2 files changed

+13
-16
lines changed

src/sagemaker/utils.py

+7-10
Original file line numberDiff line numberDiff line change
@@ -330,27 +330,24 @@ def repack_model(inference_script, source_directory, model_uri, sagemaker_sessio
330330
model_from_s3 = model_uri.startswith('s3://')
331331
if model_from_s3:
332332

333-
local_model_uri = os.path.join(tmp, 'tar_file')
334-
download_file_from_url(model_uri, local_model_uri, sagemaker_session)
333+
local_model_path = os.path.join(tmp, 'tar_file')
334+
download_file_from_url(model_uri, local_model_path, sagemaker_session)
335335

336336
new_model_path = os.path.join(tmp, new_model_name)
337337
else:
338-
local_model_uri = model_uri.replace('file://', '')
339-
new_model_path = os.path.join(os.path.dirname(local_model_uri), new_model_name)
338+
local_model_path = model_uri.replace('file://', '')
339+
new_model_path = os.path.join(os.path.dirname(local_model_path), new_model_name)
340340

341-
with tarfile.open(name=local_model_uri, mode='r:gz') as t:
341+
with tarfile.open(name=local_model_path, mode='r:gz') as t:
342342
t.extractall(path=tmp_model_dir)
343343

344344
code_dir = os.path.join(tmp_model_dir, 'code')
345345
if os.path.exists(code_dir):
346346
shutil.rmtree(code_dir, ignore_errors=True)
347347

348-
os.mkdir(code_dir)
348+
dirname = source_directory if source_directory else os.path.dirname(inference_script)
349349

350-
source_files = _list_files(inference_script, source_directory)
351-
352-
for source_file in source_files:
353-
shutil.copy(source_file, code_dir)
350+
shutil.copytree(dirname, code_dir)
354351

355352
files_to_compress = [os.path.join(tmp_model_dir, file)
356353
for file in os.listdir(tmp_model_dir)]

tests/integ/test_tfs.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -50,21 +50,21 @@ def tfs_predictor(instance_type, sagemaker_session, tf_full_version):
5050
yield predictor
5151

5252

53-
def tar_dir(directory):
54-
55-
tmp = tempfile.mkdtemp()
53+
def tar_dir(directory, tmpdir):
5654

5755
source_files = [os.path.join(directory, name) for name in os.listdir(directory)]
58-
return sagemaker.utils.create_tar_file(source_files, os.path.join(tmp, 'model.tar.gz'))
56+
return sagemaker.utils.create_tar_file(source_files, os.path.join(str(tmpdir), 'model.tar.gz'))
5957

6058

6159
@pytest.fixture(scope='module')
6260
def tfs_predictor_with_model_and_entry_point_same_tar(instance_type,
6361
sagemaker_session,
64-
tf_full_version):
62+
tf_full_version,
63+
tmpdir):
6564
endpoint_name = sagemaker.utils.unique_name_from_base('sagemaker-tensorflow-serving')
6665

67-
model_tar = tar_dir(os.path.join(tests.integ.DATA_DIR, 'tfs/tfs-test-model-with-inference'))
66+
model_tar = tar_dir(os.path.join(tests.integ.DATA_DIR, 'tfs/tfs-test-model-with-inference'),
67+
tmpdir)
6868

6969
model_data = sagemaker_session.upload_data(
7070
path=model_tar,

0 commit comments

Comments
 (0)