import six

-import sagemaker
-
ECR_URI_PATTERN = r'^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(amazonaws.com|c2s.ic.gov)(/)(.*:.*)$'
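For reference, ECR_URI_PATTERN matches ECR image URIs of the form <account>.dkr.ecr.<region>.amazonaws.com/<repo>:<tag>. A minimal sketch of what it accepts (the account ID, region, and repository below are made-up values, not taken from this change):

import re

ECR_URI_PATTERN = r'^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(amazonaws.com|c2s.ic.gov)(/)(.*:.*)$'

# Hypothetical ECR image URI used only to illustrate the pattern.
uri = '123456789012.dkr.ecr.us-west-2.amazonaws.com/my-repo:latest'
assert re.match(ECR_URI_PATTERN, uri) is not None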
@@ -300,7 +298,12 @@ def _tmpdir(suffix='', prefix='tmp'):
    shutil.rmtree(tmp)


-def repack_model(inference_script, source_directory, model_uri, sagemaker_session):
+def repack_model(inference_script,
+                 source_directory,
+                 dependencies,
+                 model_uri,
+                 repacked_model_uri,
+                 sagemaker_session):
    """Unpack model tarball and create a new model tarball with the provided code script.

    This function does the following:
@@ -311,60 +314,91 @@ def repack_model(inference_script, source_directory, model_uri, sagemaker_session):
    Args:
        inference_script (str): path or basename of the inference script that will be packed into the model
        source_directory (str): path including all the files that will be packed into the model
+        dependencies (list[str]): A list of paths to directories (absolute or relative) with
+            any additional libraries that will be exported to the container (default: []).
+            The library folders will be copied to SageMaker in the same folder where the entrypoint is copied.
+            Example:
+
+                The following call
+                >>> Estimator(entry_point='train.py', dependencies=['my/libs/common', 'virtual-env'])
+                results in the following inside the container:
+
+                >>> $ ls
+
+                >>> opt/ml/code
+                >>>     |------ train.py
+                >>>     |------ common
+                >>>     |------ virtual-env
+
+        repacked_model_uri (str): path or file system location where the new model will be saved
        model_uri (str): S3 or file system location of the original model tar
        sagemaker_session (:class:`sagemaker.session.Session`): a sagemaker session to interact with S3.

    Returns:
        str: path to the new packed model
    """
-    new_model_name = 'model-%s.tar.gz' % sagemaker.utils.sagemaker_short_timestamp()
+    dependencies = dependencies or []

    with _tmpdir() as tmp:
-        tmp_model_dir = os.path.join(tmp, 'model')
-        os.mkdir(tmp_model_dir)
+        model_dir = _extract_model(model_uri, sagemaker_session, tmp)

-        model_from_s3 = model_uri.lower().startswith('s3://')
-        if model_from_s3:
-            local_model_path = os.path.join(tmp, 'tar_file')
-            download_file_from_url(model_uri, local_model_path, sagemaker_session)
+        _create_or_update_code_dir(model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp)

-            new_model_path = os.path.join(tmp, new_model_name)
-        else:
-            local_model_path = model_uri.replace('file://', '')
-            new_model_path = os.path.join(os.path.dirname(local_model_path), new_model_name)
+        tmp_model_path = os.path.join(tmp, 'temp-model.tar.gz')
+        with tarfile.open(tmp_model_path, mode='w:gz') as t:
+            t.add(model_dir, arcname=os.path.sep)

-        with tarfile.open(name=local_model_path, mode='r:gz') as t:
-            t.extractall(path=tmp_model_dir)
+        _save_model(repacked_model_uri, tmp_model_path, sagemaker_session)

-        code_dir = os.path.join(tmp_model_dir, 'code')
-        if os.path.exists(code_dir):
-            shutil.rmtree(code_dir, ignore_errors=True)

-        if source_directory and source_directory.lower().startswith('s3://'):
-            local_code_path = os.path.join(tmp, 'local_code.tar.gz')
-            download_file_from_url(source_directory, local_code_path, sagemaker_session)
+def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session):
+    if repacked_model_uri.lower().startswith('s3://'):
+        url = parse.urlparse(repacked_model_uri)
+        bucket, key = url.netloc, url.path.lstrip('/')
+        new_key = key.replace(os.path.basename(key), os.path.basename(repacked_model_uri))

-            with tarfile.open(name=local_code_path, mode='r:gz') as t:
-                t.extractall(path=code_dir)
+        sagemaker_session.boto_session.resource('s3').Object(bucket, new_key).upload_file(
+            tmp_model_path)
+    else:
+        shutil.move(tmp_model_path, repacked_model_uri.replace('file://', ''))

-        elif source_directory:
-            shutil.copytree(source_directory, code_dir)
-        else:
-            os.mkdir(code_dir)
-            shutil.copy2(inference_script, code_dir)

-        with tarfile.open(new_model_path, mode='w:gz') as t:
-            t.add(tmp_model_dir, arcname=os.path.sep)
+def _create_or_update_code_dir(model_dir, inference_script, source_directory,
+                               dependencies, sagemaker_session, tmp):
+    code_dir = os.path.join(model_dir, 'code')
+    if os.path.exists(code_dir):
+        shutil.rmtree(code_dir, ignore_errors=True)
+    if source_directory and source_directory.lower().startswith('s3://'):
+        local_code_path = os.path.join(tmp, 'local_code.tar.gz')
+        download_file_from_url(source_directory, local_code_path, sagemaker_session)
+
+        with tarfile.open(name=local_code_path, mode='r:gz') as t:
+            t.extractall(path=code_dir)

-        if model_from_s3:
-            url = parse.urlparse(model_uri)
-            bucket, key = url.netloc, url.path.lstrip('/')
-            new_key = key.replace(os.path.basename(key), new_model_name)
+    elif source_directory:
+        shutil.copytree(source_directory, code_dir)
+    else:
+        os.mkdir(code_dir)
+        shutil.copy2(inference_script, code_dir)

-            sagemaker_session.boto_session.resource('s3').Object(bucket, new_key).upload_file(new_model_path)
-            return 's3://%s/%s' % (bucket, new_key)
+    for dependency in dependencies:
+        if os.path.isdir(dependency):
+            shutil.copytree(dependency, code_dir)
        else:
-            return 'file://%s' % new_model_path
+            shutil.copy2(dependency, code_dir)
+
+
+def _extract_model(model_uri, sagemaker_session, tmp):
+    tmp_model_dir = os.path.join(tmp, 'model')
+    os.mkdir(tmp_model_dir)
+    if model_uri.lower().startswith('s3://'):
+        local_model_path = os.path.join(tmp, 'tar_file')
+        download_file_from_url(model_uri, local_model_path, sagemaker_session)
+    else:
+        local_model_path = model_uri.replace('file://', '')
+    with tarfile.open(name=local_model_path, mode='r:gz') as t:
+        t.extractall(path=tmp_model_dir)
+    return tmp_model_dir


def download_file_from_url(url, dst, sagemaker_session):
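As a quick illustration of the new signature, here is a minimal sketch of calling the refactored repack_model. The bucket, file names, and local paths are placeholders, and the session assumes default AWS credentials are configured; it is not a prescribed usage from this change:

from sagemaker.session import Session
from sagemaker.utils import repack_model

session = Session()  # hypothetical session; requires configured AWS credentials/region

# Repack an existing model tarball so it also contains inference.py, src/, and one dependency dir.
repack_model(inference_script='inference.py',
             source_directory='src',
             dependencies=['my/libs/common'],
             model_uri='s3://my-bucket/model/model.tar.gz',
             repacked_model_uri='s3://my-bucket/model/repacked-model.tar.gz',
             sagemaker_session=session)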