@@ -330,27 +330,24 @@ def repack_model(inference_script, source_directory, model_uri, sagemaker_sessio
330
330
model_from_s3 = model_uri .startswith ('s3://' )
331
331
if model_from_s3 :
332
332
333
- local_model_uri = os .path .join (tmp , 'tar_file' )
334
- download_file_from_url (model_uri , local_model_uri , sagemaker_session )
333
+ local_model_path = os .path .join (tmp , 'tar_file' )
334
+ download_file_from_url (model_uri , local_model_path , sagemaker_session )
335
335
336
336
new_model_path = os .path .join (tmp , new_model_name )
337
337
else :
338
- local_model_uri = model_uri .replace ('file://' , '' )
339
- new_model_path = os .path .join (os .path .dirname (local_model_uri ), new_model_name )
338
+ local_model_path = model_uri .replace ('file://' , '' )
339
+ new_model_path = os .path .join (os .path .dirname (local_model_path ), new_model_name )
340
340
341
- with tarfile .open (name = local_model_uri , mode = 'r:gz' ) as t :
341
+ with tarfile .open (name = local_model_path , mode = 'r:gz' ) as t :
342
342
t .extractall (path = tmp_model_dir )
343
343
344
344
code_dir = os .path .join (tmp_model_dir , 'code' )
345
345
if os .path .exists (code_dir ):
346
346
shutil .rmtree (code_dir , ignore_errors = True )
347
347
348
- os .mkdir ( code_dir )
348
+ dirname = source_directory if source_directory else os .path . dirname ( inference_script )
349
349
350
- source_files = _list_files (inference_script , source_directory )
351
-
352
- for source_file in source_files :
353
- shutil .copy (source_file , code_dir )
350
+ shutil .copytree (dirname , code_dir )
354
351
355
352
files_to_compress = [os .path .join (tmp_model_dir , file )
356
353
for file in os .listdir (tmp_model_dir )]
0 commit comments