Skip to content

Commit 508a8a2

Browse files
committed
Update script path
1 parent 7cd8cbb commit 508a8a2

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

test/integration/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,21 @@
3030
model_cpu_tar = file_utils.make_tarfile(mnist_cpu_script,
3131
os.path.join(model_cpu_dir, "model.pth"),
3232
model_cpu_dir,
33-
preserve_script_path=True)
33+
script_path="code")
3434

3535
model_cpu_1d_dir = os.path.join(model_cpu_dir, '1d')
3636
mnist_1d_script = os.path.join(model_cpu_1d_dir, code_sub_dir, 'mnist_1d.py')
3737
model_cpu_1d_tar = file_utils.make_tarfile(mnist_1d_script,
3838
os.path.join(model_cpu_1d_dir, "model.pth"),
3939
model_cpu_1d_dir,
40-
preserve_script_path=True)
40+
script_path="code")
4141

4242
model_gpu_dir = os.path.join(mnist_path, gpu_sub_dir)
4343
mnist_gpu_script = os.path.join(model_gpu_dir, code_sub_dir, 'mnist.py')
4444
model_gpu_tar = file_utils.make_tarfile(mnist_gpu_script,
4545
os.path.join(model_gpu_dir, "model.pth"),
4646
model_gpu_dir,
47-
preserve_script_path=True)
47+
script_path="code")
4848

4949
model_eia_dir = os.path.join(mnist_path, eia_sub_dir)
5050
mnist_eia_script = os.path.join(model_eia_dir, 'mnist.py')
@@ -57,7 +57,7 @@
5757
os.path.join(model_cpu_dir, "model.pth"),
5858
model_cpu_dir,
5959
"model_call_model_fn_once.tar.gz",
60-
preserve_script_path=True)
60+
script_path="code")
6161

6262
ROLE = 'dummy/unused-role'
6363
DEFAULT_TIMEOUT = 20

test/utils/file_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616
import tarfile
1717

1818

19-
def make_tarfile(script, model, output_path, filename="model.tar.gz", preserve_script_path=False):
19+
def make_tarfile(script, model, output_path, filename="model.tar.gz", script_path=None):
2020
output_filename = os.path.join(output_path, filename)
2121
with tarfile.open(output_filename, "w:gz") as tar:
22-
if(preserve_script_path):
23-
tar.add(script, arcname=script)
22+
if(script_path):
23+
tar.add(script, arcname=os.path.join(script_path, os.path.basename(script)))
2424
else:
2525
tar.add(script, arcname=os.path.basename(script))
2626
tar.add(model, arcname=os.path.basename(model))

0 commit comments

Comments
 (0)