Skip to content

Commit a07a962

Browse files
author
Payton Staub
committed
fix: Prevent repack_model script from referencing nonexistent directories
1 parent b52b5db commit a07a962

File tree

1 file changed

+19
-12
lines changed

1 file changed

+19
-12
lines changed

src/sagemaker/workflow/_repack_model.py

+19-12
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,15 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None):
6262
with tarfile.open(name=local_path, mode="r:gz") as tf:
6363
tf.extractall(path=src_dir)
6464

65-
# copy the custom inference script to code/
66-
entry_point = os.path.join("/opt/ml/code", inference_script)
67-
shutil.copy2(entry_point, os.path.join(src_dir, "code", inference_script))
68-
69-
# copy source_dir to code/
7065
if source_dir:
66+
# copy /opt/ml/code to code/
7167
if os.path.exists(code_dir):
7268
shutil.rmtree(code_dir)
73-
shutil.copytree(source_dir, code_dir)
69+
shutil.copytree("/opt/ml/code", code_dir)
70+
else:
71+
# copy the custom inference script to code/
72+
entry_point = os.path.join("/opt/ml/code", inference_script)
73+
shutil.copy2(entry_point, os.path.join(code_dir, inference_script))
7474

7575
# copy any dependencies to code/lib/
7676
if dependencies:
@@ -79,13 +79,20 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None):
7979
lib_dir = os.path.join(code_dir, "lib")
8080
if not os.path.exists(lib_dir):
8181
os.mkdir(lib_dir)
82-
if os.path.isdir(actual_dependency_path):
83-
shutil.copytree(
84-
actual_dependency_path,
85-
os.path.join(lib_dir, os.path.basename(actual_dependency_path)),
86-
)
87-
else:
82+
if os.path.isfile(actual_dependency_path):
8883
shutil.copy2(actual_dependency_path, lib_dir)
84+
else:
85+
if os.path.exists(lib_dir):
86+
shutil.rmtree(lib_dir)
87+
# a directory is in the dependencies. we have no choice but to copy
88+
# all of /opt/ml/code into the lib dir because the original directory
89+
# was flattened by the SDK training job upload..
90+
shutil.copytree("/opt/ml/code", lib_dir)
91+
break
92+
# shutil.copytree(
93+
# actual_dependency_path,
94+
# os.path.join(lib_dir, os.path.basename(actual_dependency_path)),
95+
# )
8996

9097
# copy the "src" dir, which includes the previous training job's model and the
9198
# custom inference script, to the output of this training job

0 commit comments

Comments
 (0)