Skip to content

Commit 7e6e5ff

Browse files
committed
Simplify function
1 parent aa62d3d commit 7e6e5ff

File tree

1 file changed

+4
-14
lines changed

1 file changed

+4
-14
lines changed

src/sagemaker/fw_utils.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -140,30 +140,20 @@ def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory, lib_di
140140
tmp = tempfile.mkdtemp()
141141

142142
try:
143-
source_files = list(_expand_files_to_compress(script, directory))
144-
145-
tar_file = sagemaker.utils.create_tar_file(source_files + lib_dirs, os.path.join(tmp, _TAR_SOURCE_FILENAME))
143+
source_files = _list_files_to_compress(script, directory) + lib_dirs
144+
tar_file = sagemaker.utils.create_tar_file(source_files, os.path.join(tmp, _TAR_SOURCE_FILENAME))
146145

147146
session.resource('s3').Object(bucket, key).upload_file(tar_file)
148-
149147
finally:
150148
shutil.rmtree(tmp)
151149

152150
script_name = script if directory else os.path.basename(script)
153151
return UploadedCode(s3_prefix='s3://%s/%s' % (bucket, key), script_name=script_name)
154152

155153

156-
def _expand_files_to_compress(script, directory, additional_files=None):
157-
additional_files = additional_files or []
154+
def _list_files_to_compress(script, directory):
158155
basedir = directory if directory else os.path.dirname(script)
159-
files = [basedir] + additional_files
160-
161-
for file in files:
162-
if os.path.isfile(file):
163-
yield file
164-
else:
165-
for name in os.listdir(file):
166-
yield os.path.join(file, name)
156+
return [os.path.join(basedir, name) for name in os.listdir(basedir)]
167157

168158

169159
def framework_name_from_image(image_name):

0 commit comments

Comments
 (0)