|
| 1 | +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). You |
| 4 | +# may not use this file except in compliance with the License. A copy of |
| 5 | +# the License is located at |
| 6 | +# |
| 7 | +# http://aws.amazon.com/apache2.0/ |
| 8 | +# |
| 9 | +# or in the "license" file accompanying this file. This file is |
| 10 | +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
| 11 | +# ANY KIND, either express or implied. See the License for the specific |
| 12 | +# language governing permissions and limitations under the License. |
| 13 | +"""Repack model script for training jobs to inject entry points""" |
| 14 | +from __future__ import absolute_import |
| 15 | + |
| 16 | +import argparse |
| 17 | +import os |
| 18 | +import shutil |
| 19 | +import tarfile |
| 20 | +import tempfile |
| 21 | + |
| 22 | +# Repack Model |
| 23 | +# The following script is run via a training job which takes an existing model and a custom |
| 24 | +# entry point script as arguments. The script creates a new model archive with the custom |
| 25 | +# entry point in the "code" directory along with the existing model. Subsequently, when the model |
| 26 | +# is unpacked for inference, the custom entry point will be used. |
| 27 | +# Reference: https://docs.aws.amazon.com/sagemaker/latest/dg/amazon-sagemaker-toolkits.html |
| 28 | + |
| 29 | +# distutils.dir_util.copy_tree works way better than the half-baked |
| 30 | +# shutil.copytree which bombs on previously existing target dirs... |
| 31 | +# alas ... https://bugs.python.org/issue10948 |
| 32 | +# we'll go ahead and use the copy_tree function anyways because this |
| 33 | +# repacking is some short-lived hackery, right?? |
| 34 | +from distutils.dir_util import copy_tree |
| 35 | + |
| 36 | + |
def repack(inference_script, model_archive, dependencies=None, source_dir=None):  # pragma: no cover
    """Repack custom dependencies and code into an existing model TAR archive.

    Runs inside a SageMaker training container: reads the previous training
    job's model archive from the training input channel, unpacks it, injects
    the custom entry point (and optional dependencies) under a "code/"
    directory, and writes the merged tree to /opt/ml/model so SageMaker
    re-archives it as this job's model output.

    Args:
        inference_script (str): The path to the custom entry point.
        model_archive (str): The name or path (e.g. s3 uri) of the model TAR archive.
        dependencies (str): A space-delimited string of paths to custom dependencies.
        source_dir (str): The path to a custom source directory.
    """

    # the data directory contains a model archive generated by a previous training job
    data_directory = "/opt/ml/input/data/training"
    # only the basename matters locally; the channel download flattens the key
    model_path = os.path.join(data_directory, model_archive.split("/")[-1])

    # stage everything in a temporary directory that is cleaned up on exit
    with tempfile.TemporaryDirectory() as tmp:
        local_path = os.path.join(tmp, "local.tar.gz")
        # copy the previous training job's model archive to the temporary directory
        shutil.copy2(model_path, local_path)
        src_dir = os.path.join(tmp, "src")
        # create the "code" directory which will contain the inference script
        code_dir = os.path.join(src_dir, "code")
        os.makedirs(code_dir)
        # extract the contents of the previous training job's model archive to the
        # "src" directory of this training job
        with tarfile.open(name=local_path, mode="r:gz") as tf:
            # NOTE(review): extractall without an extraction filter is vulnerable to
            # path traversal if the archive is attacker-controlled; on Python 3.12+
            # (or patched 3.8+) pass filter="data". Left as-is because the archive
            # comes from this pipeline's own prior training job — confirm.
            tf.extractall(path=src_dir)

        if source_dir:
            # a whole source directory was provided: replace code/ with /opt/ml/code
            if os.path.exists(code_dir):
                shutil.rmtree(code_dir)
            shutil.copytree("/opt/ml/code", code_dir)
        else:
            # copy only the custom inference script to code/
            entry_point = os.path.join("/opt/ml/code", inference_script)
            shutil.copy2(entry_point, os.path.join(code_dir, inference_script))

        # copy any dependencies to code/lib/
        if dependencies:
            # str.split() (no argument) discards empty tokens, so repeated or stray
            # spaces in the dependency string cannot produce a "" entry that would
            # resolve to /opt/ml/code itself and wrongly trigger the directory branch
            for dependency in dependencies.split():
                actual_dependency_path = os.path.join("/opt/ml/code", dependency)
                lib_dir = os.path.join(code_dir, "lib")
                if not os.path.exists(lib_dir):
                    os.mkdir(lib_dir)
                if os.path.isfile(actual_dependency_path):
                    shutil.copy2(actual_dependency_path, lib_dir)
                else:
                    if os.path.exists(lib_dir):
                        shutil.rmtree(lib_dir)
                    # a directory is in the dependencies. we have to copy
                    # all of /opt/ml/code into the lib dir because the original
                    # directory was flattened by the SDK training job upload..
                    shutil.copytree("/opt/ml/code", lib_dir)
                    break

        # copy the "src" dir, which includes the previous training job's model and
        # the custom inference script, to the output of this training job.
        # dirs_exist_ok=True (Python 3.8+) merges into the existing /opt/ml/model,
        # replacing the deprecated distutils.dir_util.copy_tree (removed in 3.12).
        shutil.copytree(src_dir, "/opt/ml/model", dirs_exist_ok=True)
| 96 | + |
| 97 | + |
if __name__ == "__main__":  # pragma: no cover
    # CLI entry point: flag name -> default value. All arguments are strings;
    # unknown extra arguments (e.g. SageMaker hyperparameter pass-through) are
    # tolerated via parse_known_args.
    cli = argparse.ArgumentParser()
    for flag, default in (
        ("--inference_script", "inference.py"),
        ("--dependencies", None),
        ("--source_dir", None),
        ("--model_archive", "model.tar.gz"),
    ):
        cli.add_argument(flag, type=str, default=default)
    parsed, _unknown = cli.parse_known_args()
    repack(
        inference_script=parsed.inference_script,
        dependencies=parsed.dependencies,
        source_dir=parsed.source_dir,
        model_archive=parsed.model_archive,
    )
0 commit comments