-
Notifications
You must be signed in to change notification settings - Fork 1.2k
fix: repack_model script used in pipelines to support source_dir and dependencies #2645
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5b16b4f
db68957
ecbc5e5
bf0899b
1326181
70e22f5
0d38b3a
7ffed8c
c6c1f0a
2e469ab
4bc1550
2dedfd3
82d795e
558cdfe
e472fbe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,15 +34,19 @@ | |
from distutils.dir_util import copy_tree | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--inference_script", type=str, default="inference.py") | ||
parser.add_argument("--model_archive", type=str, default="model.tar.gz") | ||
args = parser.parse_args() | ||
def repack(inference_script, model_archive, dependencies=None, source_dir=None): | ||
"""Repack custom dependencies and code into an existing model TAR archive | ||
|
||
Args: | ||
inference_script (str): The path to the custom entry point. | ||
model_archive (str): The name of the model TAR archive. | ||
dependencies (str): A space-delimited string of paths to custom dependencies. | ||
source_dir (str): The path to a custom source directory. | ||
""" | ||
|
||
# the data directory contains a model archive generated by a previous training job | ||
data_directory = "/opt/ml/input/data/training" | ||
model_path = os.path.join(data_directory, args.model_archive) | ||
model_path = os.path.join(data_directory, model_archive) | ||
|
||
# create a temporary directory | ||
with tempfile.TemporaryDirectory() as tmp: | ||
|
@@ -51,17 +55,53 @@ | |
shutil.copy2(model_path, local_path) | ||
src_dir = os.path.join(tmp, "src") | ||
# create the "code" directory which will contain the inference script | ||
os.makedirs(os.path.join(src_dir, "code")) | ||
code_dir = os.path.join(src_dir, "code") | ||
os.makedirs(code_dir) | ||
# extract the contents of the previous training job's model archive to the "src" | ||
# directory of this training job | ||
with tarfile.open(name=local_path, mode="r:gz") as tf: | ||
tf.extractall(path=src_dir) | ||
|
||
# generate a path to the custom inference script | ||
entry_point = os.path.join("/opt/ml/code", args.inference_script) | ||
# copy the custom inference script to the "src" dir | ||
shutil.copy2(entry_point, os.path.join(src_dir, "code", args.inference_script)) | ||
# copy the custom inference script to code/ | ||
entry_point = os.path.join("/opt/ml/code", inference_script) | ||
shutil.copy2(entry_point, os.path.join(src_dir, "code", inference_script)) | ||
|
||
# copy source_dir to code/ | ||
if source_dir: | ||
if os.path.exists(code_dir): | ||
shutil.rmtree(code_dir) | ||
shutil.copytree(source_dir, code_dir) | ||
|
||
# copy any dependencies to code/lib/ | ||
if dependencies: | ||
for dependency in dependencies.split(" "): | ||
actual_dependency_path = os.path.join("/opt/ml/code", dependency) | ||
lib_dir = os.path.join(code_dir, "lib") | ||
if not os.path.exists(lib_dir): | ||
os.mkdir(lib_dir) | ||
if os.path.isdir(actual_dependency_path): | ||
shutil.copytree( | ||
actual_dependency_path, | ||
os.path.join(lib_dir, os.path.basename(actual_dependency_path)), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
) | ||
else: | ||
shutil.copy2(actual_dependency_path, lib_dir) | ||
|
||
# copy the "src" dir, which includes the previous training job's model and the | ||
# custom inference script, to the output of this training job | ||
copy_tree(src_dir, "/opt/ml/model") | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--inference_script", type=str, default="inference.py") | ||
parser.add_argument("--dependencies", type=str, default=None) | ||
parser.add_argument("--source_dir", type=str, default=None) | ||
parser.add_argument("--model_archive", type=str, default="model.tar.gz") | ||
Comment on lines
+97
to
+100
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this also has a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ack. I don't think it would be much use since this script is intended for, and will only work in, a training job |
||
args, extra = parser.parse_known_args() | ||
repack( | ||
inference_script=args.inference_script, | ||
dependencies=args.dependencies, | ||
source_dir=args.source_dir, | ||
model_archive=args.model_archive, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
# language governing permissions and limitations under the License. | ||
from __future__ import absolute_import | ||
from sagemaker.workflow import _repack_model | ||
|
||
from pathlib import Path | ||
import shutil | ||
import tarfile | ||
import os | ||
import pytest | ||
import time | ||
|
||
|
||
@pytest.mark.skip( | ||
reason="""This test operates on the root file system | ||
and will likely fail due to permission errors. | ||
Temporarily remove this skip decorator and run | ||
the test after making changes to _repack_model.py""" | ||
) | ||
def test_repack_entry_point_only(tmp): | ||
model_name = "xg-boost-model" | ||
fake_model_path = os.path.join(tmp, model_name) | ||
|
||
# create a fake model | ||
open(fake_model_path, "w") | ||
|
||
# create model.tar.gz | ||
model_tar_name = "model-%s.tar.gz" % time.time() | ||
model_tar_location = os.path.join(tmp, model_tar_name) | ||
with tarfile.open(model_tar_location, mode="w:gz") as t: | ||
t.add(fake_model_path, arcname=model_name) | ||
|
||
# move model.tar.gz to /opt/ml/input/data/training | ||
Path("/opt/ml/input/data/training").mkdir(parents=True, exist_ok=True) | ||
shutil.move(model_tar_location, os.path.join("/opt/ml/input/data/training", model_tar_name)) | ||
|
||
# create files that will be added to model.tar.gz | ||
create_file_tree( | ||
"/opt/ml/code", | ||
[ | ||
"inference.py", | ||
], | ||
) | ||
|
||
# repack | ||
_repack_model.repack(inference_script="inference.py", model_archive=model_tar_name) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice, would be nice to also cover There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One of the reasons I moved the logic to a function was because I couldn't figure out how to invoke the |
||
|
||
# /opt/ml/model should now have the original model and the inference script | ||
assert os.path.exists(os.path.join("/opt/ml/model", model_name)) | ||
assert os.path.exists(os.path.join("/opt/ml/model/code", "inference.py")) | ||
|
||
|
||
@pytest.mark.skip( | ||
reason="""This test operates on the root file system | ||
and will likely fail due to permission errors. | ||
Temporarily remove this skip decorator and run | ||
the test after making changes to _repack_model.py""" | ||
) | ||
def test_repack_with_dependencies(tmp): | ||
model_name = "xg-boost-model" | ||
fake_model_path = os.path.join(tmp, model_name) | ||
|
||
# create a fake model | ||
open(fake_model_path, "w") | ||
|
||
# create model.tar.gz | ||
model_tar_name = "model-%s.tar.gz" % time.time() | ||
model_tar_location = os.path.join(tmp, model_tar_name) | ||
with tarfile.open(model_tar_location, mode="w:gz") as t: | ||
t.add(fake_model_path, arcname=model_name) | ||
|
||
# move model.tar.gz to /opt/ml/input/data/training | ||
Path("/opt/ml/input/data/training").mkdir(parents=True, exist_ok=True) | ||
shutil.move(model_tar_location, os.path.join("/opt/ml/input/data/training", model_tar_name)) | ||
|
||
# create files that will be added to model.tar.gz | ||
create_file_tree( | ||
"/opt/ml/code", | ||
["inference.py", "dependencies/a", "bb", "dependencies/some/dir/b"], | ||
) | ||
|
||
# repack | ||
_repack_model.repack( | ||
inference_script="inference.py", | ||
model_archive=model_tar_name, | ||
dependencies=["dependencies/a", "bb", "dependencies/some/dir"], | ||
) | ||
|
||
# /opt/ml/model should now have the original model and the inference script | ||
assert os.path.exists(os.path.join("/opt/ml/model", model_name)) | ||
assert os.path.exists(os.path.join("/opt/ml/model/code", "inference.py")) | ||
assert os.path.exists(os.path.join("/opt/ml/model/code/lib", "a")) | ||
assert os.path.exists(os.path.join("/opt/ml/model/code/lib", "bb")) | ||
assert os.path.exists(os.path.join("/opt/ml/model/code/lib/dir", "b")) | ||
|
||
|
||
@pytest.mark.skip( | ||
reason="""This test operates on the root file system | ||
and will likely fail due to permission errors. | ||
Temporarily remove this skip decorator and run | ||
the test after making changes to _repack_model.py""" | ||
) | ||
def test_repack_with_source_dir_and_dependencies(tmp): | ||
model_name = "xg-boost-model" | ||
fake_model_path = os.path.join(tmp, model_name) | ||
|
||
# create a fake model | ||
open(fake_model_path, "w") | ||
|
||
# create model.tar.gz | ||
model_tar_name = "model-%s.tar.gz" % time.time() | ||
model_tar_location = os.path.join(tmp, model_tar_name) | ||
with tarfile.open(model_tar_location, mode="w:gz") as t: | ||
t.add(fake_model_path, arcname=model_name) | ||
|
||
# move model.tar.gz to /opt/ml/input/data/training | ||
Path("/opt/ml/input/data/training").mkdir(parents=True, exist_ok=True) | ||
shutil.move(model_tar_location, os.path.join("/opt/ml/input/data/training", model_tar_name)) | ||
|
||
# create files that will be added to model.tar.gz | ||
create_file_tree( | ||
"/opt/ml/code", | ||
[ | ||
"inference.py", | ||
"dependencies/a", | ||
"bb", | ||
"dependencies/some/dir/b", | ||
"sourcedir/foo.py", | ||
"sourcedir/some/dir/a", | ||
], | ||
) | ||
|
||
# repack | ||
_repack_model.repack( | ||
inference_script="inference.py", | ||
model_archive=model_tar_name, | ||
dependencies=["dependencies/a", "bb", "dependencies/some/dir"], | ||
source_dir="sourcedir", | ||
) | ||
|
||
# /opt/ml/model should now have the original model and the inference script | ||
assert os.path.exists(os.path.join("/opt/ml/model", model_name)) | ||
assert os.path.exists(os.path.join("/opt/ml/model/code", "inference.py")) | ||
assert os.path.exists(os.path.join("/opt/ml/model/code/lib", "a")) | ||
assert os.path.exists(os.path.join("/opt/ml/model/code/lib", "bb")) | ||
assert os.path.exists(os.path.join("/opt/ml/model/code/lib/dir", "b")) | ||
assert os.path.exists(os.path.join("/opt/ml/model/code/", "foo.py")) | ||
assert os.path.exists(os.path.join("/opt/ml/model/code/some/dir", "a")) | ||
|
||
|
||
def create_file_tree(root, tree): | ||
for file in tree: | ||
try: | ||
os.makedirs(os.path.join(root, os.path.dirname(file))) | ||
except: # noqa: E722 Using bare except because p2/3 incompatibility issues. | ||
pass | ||
with open(os.path.join(root, file), "a") as f: | ||
f.write(file) | ||
|
||
|
||
@pytest.fixture() | ||
def tmp(tmpdir): | ||
yield str(tmpdir) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit this description makes more sense on a cli argument than a method parameter