# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Repack model script for training jobs to inject entry points"""
from __future__ import absolute_import

import argparse
import os
import shutil
import tarfile
import tempfile

# Repack Model
# The following script is run via a training job which takes an existing model and a custom
# entry point script as arguments. The script creates a new model archive with the custom
# entry point in the "code" directory along with the existing model. Subsequently, when the model
# is unpacked for inference, the custom entry point will be used.
# Reference: https://docs.aws.amazon.com/sagemaker/latest/dg/amazon-sagemaker-toolkits.html

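# Illustrative only: a script like this is normally launched through the SageMaker Python SDK as
# a framework training job, with the existing archive passed on an input channel and the two
# arguments below passed as hyperparameters. The estimator below is a hedged sketch, not the
# SDK's actual internal wiring; bucket, role, and instance values are placeholders.
#
#   from sagemaker.sklearn import SKLearn
#
#   estimator = SKLearn(
#       entry_point="_repack_model.py",
#       source_dir="<dir containing this script and inference.py>",
#       framework_version="0.23-1",
#       role="<execution-role-arn>",
#       instance_count=1,
#       instance_type="ml.m5.large",
#       hyperparameters={"inference_script": "inference.py", "model_archive": "model.tar.gz"},
#   )
#   estimator.fit({"training": "s3://<bucket>/<prefix-containing-model.tar.gz>/"})
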
# distutils.dir_util.copy_tree works way better than the half-baked
# shutil.copytree which bombs on previously existing target dirs...
# alas ... https://bugs.python.org/issue10948
# we'll go ahead and use the copy_tree function anyways because this
# repacking is some short-lived hackery, right??
from distutils.dir_util import copy_tree


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--inference_script", type=str, default="inference.py")
    parser.add_argument("--model_archive", type=str, default="model.tar.gz")
    args = parser.parse_args()
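    # When launched through the SageMaker Python SDK, these two arguments are typically supplied
    # as training-job hyperparameters; script mode forwards each hyperparameter to the entry
    # point as a "--key value" command-line argument.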
42
+
43
+ # the data directory contains a model archive generated by a previous training job
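    # (SageMaker mounts each input channel at /opt/ml/input/data/<channel_name>, so the archive
    # supplied on the "training" channel shows up under the path below)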
    data_directory = "/opt/ml/input/data/training"
    model_path = os.path.join(data_directory, args.model_archive)

    # create a temporary directory
    with tempfile.TemporaryDirectory() as tmp:
        local_path = os.path.join(tmp, "local.tar.gz")
        # copy the previous training job's model archive to the temporary directory
        shutil.copy2(model_path, local_path)
        src_dir = os.path.join(tmp, "src")
        # create the "code" directory which will contain the inference script
        os.makedirs(os.path.join(src_dir, "code"))
        # extract the contents of the previous training job's model archive to the "src"
        # directory of this training job
        with tarfile.open(name=local_path, mode="r:gz") as tf:
            tf.extractall(path=src_dir)

        # generate a path to the custom inference script
        entry_point = os.path.join("/opt/ml/code", args.inference_script)
        # copy the custom inference script to the "src" dir
        shutil.copy2(entry_point, os.path.join(src_dir, "code", args.inference_script))

        # copy the "src" dir, which includes the previous training job's model and the
        # custom inference script, to the output of this training job
        copy_tree(src_dir, "/opt/ml/model")
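
        # At this point /opt/ml/model should contain, roughly:
        #
        #   /opt/ml/model/
        #   ├── <original artifacts from the previous job's model archive>
        #   └── code/
        #       └── inference.py   (or whatever --inference_script names)
        #
        # SageMaker tars the contents of /opt/ml/model into this job's output model.tar.gz,
        # so the repacked archive carries the custom entry point alongside the existing model.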