Skip to content

fix: repack_model script used in pipelines to support source_dir and dependencies #2645

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Oct 12, 2021
Merged
62 changes: 51 additions & 11 deletions src/sagemaker/workflow/_repack_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,19 @@
from distutils.dir_util import copy_tree


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--inference_script", type=str, default="inference.py")
parser.add_argument("--model_archive", type=str, default="model.tar.gz")
args = parser.parse_args()
def repack(inference_script, model_archive, dependencies=None, source_dir=None):
"""Repack custom dependencies and code into an existing model TAR archive

Args:
inference_script (str): The path to the custom entry point.
model_archive (str): The name of the model TAR archive.
dependencies (str): A space-delimited string of paths to custom dependencies.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this description makes more sense on a CLI argument than on a method parameter

source_dir (str): The path to a custom source directory.
"""

# the data directory contains a model archive generated by a previous training job
data_directory = "/opt/ml/input/data/training"
model_path = os.path.join(data_directory, args.model_archive)
model_path = os.path.join(data_directory, model_archive)

# create a temporary directory
with tempfile.TemporaryDirectory() as tmp:
Expand All @@ -51,17 +55,53 @@
shutil.copy2(model_path, local_path)
src_dir = os.path.join(tmp, "src")
# create the "code" directory which will contain the inference script
os.makedirs(os.path.join(src_dir, "code"))
code_dir = os.path.join(src_dir, "code")
os.makedirs(code_dir)
# extract the contents of the previous training job's model archive to the "src"
# directory of this training job
with tarfile.open(name=local_path, mode="r:gz") as tf:
tf.extractall(path=src_dir)

# generate a path to the custom inference script
entry_point = os.path.join("/opt/ml/code", args.inference_script)
# copy the custom inference script to the "src" dir
shutil.copy2(entry_point, os.path.join(src_dir, "code", args.inference_script))
# copy the custom inference script to code/
entry_point = os.path.join("/opt/ml/code", inference_script)
shutil.copy2(entry_point, os.path.join(src_dir, "code", inference_script))

# copy source_dir to code/
if source_dir:
if os.path.exists(code_dir):
shutil.rmtree(code_dir)
shutil.copytree(source_dir, code_dir)

# copy any dependencies to code/lib/
if dependencies:
for dependency in dependencies.split(" "):
actual_dependency_path = os.path.join("/opt/ml/code", dependency)
lib_dir = os.path.join(code_dir, "lib")
if not os.path.exists(lib_dir):
os.mkdir(lib_dir)
if os.path.isdir(actual_dependency_path):
shutil.copytree(
actual_dependency_path,
os.path.join(lib_dir, os.path.basename(actual_dependency_path)),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is os.path.join needed here but not on line 77?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

os.path.join is being used for both calls, but on line 77 it's being done in the assignment of the s and d variables

)
else:
shutil.copy2(actual_dependency_path, lib_dir)

# copy the "src" dir, which includes the previous training job's model and the
# custom inference script, to the output of this training job
copy_tree(src_dir, "/opt/ml/model")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--inference_script", type=str, default="inference.py")
parser.add_argument("--dependencies", type=str, default=None)
parser.add_argument("--source_dir", type=str, default=None)
parser.add_argument("--model_archive", type=str, default="model.tar.gz")
Comment on lines +97 to +100
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this also has a help param if that would benefit users

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ack. I don't think it would be much use since this script is intended for, and will only work in, a training job

args, extra = parser.parse_known_args()
repack(
inference_script=args.inference_script,
dependencies=args.dependencies,
source_dir=args.source_dir,
model_archive=args.model_archive,
)
7 changes: 7 additions & 0 deletions src/sagemaker/workflow/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,11 @@ def __init__(
self._source_dir = source_dir
self._dependencies = dependencies

# convert dependencies array into space-delimited string
dependencies_hyperparameter = None
if self._dependencies:
dependencies_hyperparameter = " ".join(self._dependencies)

# the real estimator and inputs
repacker = SKLearn(
framework_version=FRAMEWORK_VERSION,
Expand All @@ -157,6 +162,8 @@ def __init__(
hyperparameters={
"inference_script": self._entry_point_basename,
"model_archive": self._model_archive,
"dependencies": dependencies_hyperparameter,
"source_dir": self._source_dir,
},
subnets=subnets,
security_group_ids=security_group_ids,
Expand Down
174 changes: 174 additions & 0 deletions tests/unit/sagemaker/workflow/test_repack_model_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
# language governing permissions and limitations under the License.
from __future__ import absolute_import
from sagemaker.workflow import _repack_model

from pathlib import Path
import shutil
import tarfile
import os
import pytest
import time


@pytest.mark.skip(
reason="""This test operates on the root file system
and will likely fail due to permission errors.
Temporarily remove this skip decorator and run
the test after making changes to _repack_model.py"""
)
def test_repack_entry_point_only(tmp):
model_name = "xg-boost-model"
fake_model_path = os.path.join(tmp, model_name)

# create a fake model
open(fake_model_path, "w")

# create model.tar.gz
model_tar_name = "model-%s.tar.gz" % time.time()
model_tar_location = os.path.join(tmp, model_tar_name)
with tarfile.open(model_tar_location, mode="w:gz") as t:
t.add(fake_model_path, arcname=model_name)

# move model.tar.gz to /opt/ml/input/data/training
Path("/opt/ml/input/data/training").mkdir(parents=True, exist_ok=True)
shutil.move(model_tar_location, os.path.join("/opt/ml/input/data/training", model_tar_name))

# create files that will be added to model.tar.gz
create_file_tree(
"/opt/ml/code",
[
"inference.py",
],
)

# repack
_repack_model.repack(inference_script="inference.py", model_archive=model_tar_name)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice, would be nice to also cover __main__

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One of the reasons I moved the logic to a function was because I couldn't figure out how to invoke the __main__ of a module from within python :) As far as I understand, __name__ is only set to __main__ if it's called from a command line. Suggestions welcome


# /opt/ml/model should now have the original model and the inference script
assert os.path.exists(os.path.join("/opt/ml/model", model_name))
assert os.path.exists(os.path.join("/opt/ml/model/code", "inference.py"))


@pytest.mark.skip(
    reason="""This test operates on the root file system
and will likely fail due to permission errors.
Temporarily remove this skip decorator and run
the test after making changes to _repack_model.py"""
)
def test_repack_with_dependencies(tmp):
    """Repack a model archive together with an entry point and extra dependencies.

    Builds a throwaway model archive, stages it where ``repack`` reads its input,
    creates the entry point and dependency files under ``/opt/ml/code``, and then
    checks that the model, the inference script and every dependency (under
    ``code/lib``) show up in ``/opt/ml/model``.
    """
    model_name = "xg-boost-model"
    fake_model_path = os.path.join(tmp, model_name)

    # create a fake (empty) model; the context manager closes the handle right away
    with open(fake_model_path, "w"):
        pass

    # create model.tar.gz
    model_tar_name = "model-%s.tar.gz" % time.time()
    model_tar_location = os.path.join(tmp, model_tar_name)
    with tarfile.open(model_tar_location, mode="w:gz") as t:
        t.add(fake_model_path, arcname=model_name)

    # move model.tar.gz to /opt/ml/input/data/training
    Path("/opt/ml/input/data/training").mkdir(parents=True, exist_ok=True)
    shutil.move(model_tar_location, os.path.join("/opt/ml/input/data/training", model_tar_name))

    # create files that will be added to model.tar.gz
    create_file_tree(
        "/opt/ml/code",
        ["inference.py", "dependencies/a", "bb", "dependencies/some/dir/b"],
    )

    # repack; ``repack`` takes dependencies as one space-delimited string (it calls
    # ``dependencies.split(" ")``), so the paths are joined rather than passed as a list
    _repack_model.repack(
        inference_script="inference.py",
        model_archive=model_tar_name,
        dependencies=" ".join(["dependencies/a", "bb", "dependencies/some/dir"]),
    )

    # /opt/ml/model should now have the original model and the inference script
    assert os.path.exists(os.path.join("/opt/ml/model", model_name))
    assert os.path.exists(os.path.join("/opt/ml/model/code", "inference.py"))
    # file dependencies land directly in code/lib; directory dependencies are copied
    # under their basename ("dependencies/some/dir" -> "code/lib/dir")
    assert os.path.exists(os.path.join("/opt/ml/model/code/lib", "a"))
    assert os.path.exists(os.path.join("/opt/ml/model/code/lib", "bb"))
    assert os.path.exists(os.path.join("/opt/ml/model/code/lib/dir", "b"))


@pytest.mark.skip(
    reason="""This test operates on the root file system
and will likely fail due to permission errors.
Temporarily remove this skip decorator and run
the test after making changes to _repack_model.py"""
)
def test_repack_with_source_dir_and_dependencies(tmp):
    """Repack a model archive with an entry point, a source dir and dependencies.

    Builds a throwaway model archive, stages it where ``repack`` reads its input,
    creates the entry point, dependency and source-dir files under
    ``/opt/ml/code``, and then checks the resulting layout of ``/opt/ml/model``.
    """
    model_name = "xg-boost-model"
    fake_model_path = os.path.join(tmp, model_name)

    # create a fake (empty) model; the context manager closes the handle right away
    with open(fake_model_path, "w"):
        pass

    # create model.tar.gz
    model_tar_name = "model-%s.tar.gz" % time.time()
    model_tar_location = os.path.join(tmp, model_tar_name)
    with tarfile.open(model_tar_location, mode="w:gz") as t:
        t.add(fake_model_path, arcname=model_name)

    # move model.tar.gz to /opt/ml/input/data/training
    Path("/opt/ml/input/data/training").mkdir(parents=True, exist_ok=True)
    shutil.move(model_tar_location, os.path.join("/opt/ml/input/data/training", model_tar_name))

    # create files that will be added to model.tar.gz
    create_file_tree(
        "/opt/ml/code",
        [
            "inference.py",
            "dependencies/a",
            "bb",
            "dependencies/some/dir/b",
            "sourcedir/foo.py",
            "sourcedir/some/dir/a",
        ],
    )

    # repack; ``repack`` takes dependencies as one space-delimited string (it calls
    # ``dependencies.split(" ")``), so the paths are joined rather than passed as a list
    # NOTE(review): source_dir is relative here while the fixture files live under
    # /opt/ml/code; this presumably relies on the working directory being
    # /opt/ml/code when the test is run -- TODO confirm.
    _repack_model.repack(
        inference_script="inference.py",
        model_archive=model_tar_name,
        dependencies=" ".join(["dependencies/a", "bb", "dependencies/some/dir"]),
        source_dir="sourcedir",
    )

    # /opt/ml/model should now have the original model and the inference script
    assert os.path.exists(os.path.join("/opt/ml/model", model_name))
    # NOTE(review): repack appears to replace code/ with a copy of source_dir, and
    # "sourcedir" holds no inference.py, so this assertion may not hold as written
    # -- verify against _repack_model.repack before un-skipping.
    assert os.path.exists(os.path.join("/opt/ml/model/code", "inference.py"))
    # file dependencies land directly in code/lib; directory dependencies are copied
    # under their basename ("dependencies/some/dir" -> "code/lib/dir")
    assert os.path.exists(os.path.join("/opt/ml/model/code/lib", "a"))
    assert os.path.exists(os.path.join("/opt/ml/model/code/lib", "bb"))
    assert os.path.exists(os.path.join("/opt/ml/model/code/lib/dir", "b"))
    # contents of the source dir are expected directly under code/
    assert os.path.exists(os.path.join("/opt/ml/model/code/", "foo.py"))
    assert os.path.exists(os.path.join("/opt/ml/model/code/some/dir", "a"))


def create_file_tree(root, tree):
    """Create a set of files under ``root``, making parent directories as needed.

    Each file receives its own relative path as content. Files are opened in
    append mode, so calling this twice for the same path appends the path again
    instead of truncating the file.

    Args:
        root (str): Directory under which the files are created.
        tree (list[str]): File paths relative to ``root``.
    """
    for rel_path in tree:
        parent_dir = os.path.join(root, os.path.dirname(rel_path))
        # exist_ok tolerates directories that already exist without hiding other
        # failures (e.g. permission errors) the way a bare ``except: pass`` would
        os.makedirs(parent_dir, exist_ok=True)
        with open(os.path.join(root, rel_path), "a") as f:
            f.write(rel_path)


@pytest.fixture()
def tmp(tmpdir):
    """Provide the ``tmpdir`` fixture value converted to a plain ``str`` path."""
    tmp_dir_path = str(tmpdir)
    yield tmp_dir_path
8 changes: 8 additions & 0 deletions tests/unit/sagemaker/workflow/test_step_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ def test_register_model_sip(estimator, model_metrics):

def test_register_model_with_model_repack_with_estimator(estimator, model_metrics):
model_data = f"s3://{BUCKET}/model.tar.gz"
dummy_requirements = f"{DATA_DIR}/dummy_requirements.txt"
register_model = RegisterModel(
name="RegisterModelStep",
estimator=estimator,
Expand All @@ -379,6 +380,7 @@ def test_register_model_with_model_repack_with_estimator(estimator, model_metric
approval_status="Approved",
description="description",
entry_point=f"{DATA_DIR}/dummy_script.py",
dependencies=[dummy_requirements],
depends_on=["TestStep"],
tags=[{"Key": "myKey", "Value": "myValue"}],
)
Expand All @@ -405,6 +407,7 @@ def test_register_model_with_model_repack_with_estimator(estimator, model_metric
},
"HyperParameters": {
"inference_script": '"dummy_script.py"',
"dependencies": f'"{dummy_requirements}"',
"model_archive": '"model.tar.gz"',
"sagemaker_submit_directory": '"s3://{}/{}/source/sourcedir.tar.gz"'.format(
BUCKET, repacker_job_name.replace('"', "")
Expand All @@ -413,6 +416,7 @@ def test_register_model_with_model_repack_with_estimator(estimator, model_metric
"sagemaker_container_log_level": "20",
"sagemaker_job_name": repacker_job_name,
"sagemaker_region": f'"{REGION}"',
"source_dir": "null",
},
"InputDataConfig": [
{
Expand Down Expand Up @@ -528,6 +532,8 @@ def test_register_model_with_model_repack_with_model(model, model_metrics):
"sagemaker_container_log_level": "20",
"sagemaker_job_name": repacker_job_name,
"sagemaker_region": f'"{REGION}"',
"dependencies": "null",
"source_dir": "null",
},
"InputDataConfig": [
{
Expand Down Expand Up @@ -631,6 +637,7 @@ def test_register_model_with_model_repack_with_pipeline_model(pipeline_model, mo
"S3OutputPath": f"s3://{BUCKET}/",
},
"HyperParameters": {
"dependencies": "null",
"inference_script": '"dummy_script.py"',
"model_archive": '"model.tar.gz"',
"sagemaker_submit_directory": '"s3://{}/{}/source/sourcedir.tar.gz"'.format(
Expand All @@ -640,6 +647,7 @@ def test_register_model_with_model_repack_with_pipeline_model(pipeline_model, mo
"sagemaker_container_log_level": "20",
"sagemaker_job_name": repacker_job_name,
"sagemaker_region": f'"{REGION}"',
"source_dir": "null",
},
"InputDataConfig": [
{
Expand Down