|
34 | 34 | from distutils.dir_util import copy_tree
|
35 | 35 |
|
36 | 36 |
|
37 |
| -if __name__ == "__main__": |
38 |
| - parser = argparse.ArgumentParser() |
39 |
| - parser.add_argument("--inference_script", type=str, default="inference.py") |
40 |
| - parser.add_argument("--model_archive", type=str, default="model.tar.gz") |
41 |
| - args = parser.parse_args() |
| 37 | +def repack(inference_script, model_archive, dependencies=None, source_dir=None): |
| 38 | + """Repack custom dependencies and code into an existing model TAR archive |
| 39 | +
|
| 40 | + Args: |
| 41 | + inference_script (str): The path to the custom entry point. |
| 42 | + model_archive (str): The name of the model TAR archive. |
| 43 | + dependencies (str): A space-delimited string of paths to custom dependencies. |
| 44 | + source_dir (str): The path to a custom source directory. |
| 45 | + """ |
42 | 46 |
|
43 | 47 | # the data directory contains a model archive generated by a previous training job
|
44 | 48 | data_directory = "/opt/ml/input/data/training"
|
45 |
| - model_path = os.path.join(data_directory, args.model_archive) |
| 49 | + model_path = os.path.join(data_directory, model_archive) |
46 | 50 |
|
47 | 51 | # create a temporary directory
|
48 | 52 | with tempfile.TemporaryDirectory() as tmp:
|
|
51 | 55 | shutil.copy2(model_path, local_path)
|
52 | 56 | src_dir = os.path.join(tmp, "src")
|
53 | 57 | # create the "code" directory which will contain the inference script
|
54 |
| - os.makedirs(os.path.join(src_dir, "code")) |
| 58 | + code_dir = os.path.join(src_dir, "code") |
| 59 | + os.makedirs(code_dir) |
55 | 60 | # extract the contents of the previous training job's model archive to the "src"
|
56 | 61 | # directory of this training job
|
57 | 62 | with tarfile.open(name=local_path, mode="r:gz") as tf:
|
58 | 63 | tf.extractall(path=src_dir)
|
59 | 64 |
|
60 |
| - # generate a path to the custom inference script |
61 |
| - entry_point = os.path.join("/opt/ml/code", args.inference_script) |
62 |
| - # copy the custom inference script to the "src" dir |
63 |
| - shutil.copy2(entry_point, os.path.join(src_dir, "code", args.inference_script)) |
| 65 | + # copy the custom inference script to code/ |
| 66 | + entry_point = os.path.join("/opt/ml/code", inference_script) |
| 67 | + shutil.copy2(entry_point, os.path.join(src_dir, "code", inference_script)) |
| 68 | + |
| 69 | + # copy source_dir to code/ |
| 70 | + if source_dir: |
| 71 | + if os.path.exists(code_dir): |
| 72 | + shutil.rmtree(code_dir) |
| 73 | + shutil.copytree(source_dir, code_dir) |
| 74 | + |
| 75 | + # copy any dependencies to code/lib/ |
| 76 | + if dependencies: |
| 77 | + for dependency in dependencies.split(" "): |
| 78 | + actual_dependency_path = os.path.join("/opt/ml/code", dependency) |
| 79 | + lib_dir = os.path.join(code_dir, "lib") |
| 80 | + if not os.path.exists(lib_dir): |
| 81 | + os.mkdir(lib_dir) |
| 82 | + if os.path.isdir(actual_dependency_path): |
| 83 | + shutil.copytree( |
| 84 | + actual_dependency_path, |
| 85 | + os.path.join(lib_dir, os.path.basename(actual_dependency_path)), |
| 86 | + ) |
| 87 | + else: |
| 88 | + shutil.copy2(actual_dependency_path, lib_dir) |
64 | 89 |
|
65 | 90 | # copy the "src" dir, which includes the previous training job's model and the
|
66 | 91 | # custom inference script, to the output of this training job
|
67 | 92 | copy_tree(src_dir, "/opt/ml/model")
|
| 93 | + |
| 94 | + |
| 95 | +if __name__ == "__main__": |
| 96 | + parser = argparse.ArgumentParser() |
| 97 | + parser.add_argument("--inference_script", type=str, default="inference.py") |
| 98 | + parser.add_argument("--dependencies", type=str, default=None) |
| 99 | + parser.add_argument("--source_dir", type=str, default=None) |
| 100 | + parser.add_argument("--model_archive", type=str, default="model.tar.gz") |
| 101 | + args, extra = parser.parse_known_args() |
| 102 | + repack( |
| 103 | + inference_script=args.inference_script, |
| 104 | + dependencies=args.dependencies, |
| 105 | + source_dir=args.source_dir, |
| 106 | + model_archive=args.model_archive, |
| 107 | + ) |
0 commit comments