
Commit 50840db

Authored by staubhp (Payton Staub), with shreyapandit, ahsan-z-khan, and jeniyat
fix: repack_model script used in pipelines to support source_dir and dependencies (#2645)
Co-authored-by: Payton Staub <[email protected]>
Co-authored-by: Shreya Pandit <[email protected]>
Co-authored-by: Ahsan Khan <[email protected]>
Co-authored-by: Jeniya Tabassum <[email protected]>
1 parent 4fe5ea4 commit 50840db

File tree

4 files changed: +240 / -11 lines changed

src/sagemaker/workflow/_repack_model.py

Lines changed: 51 additions & 11 deletions
@@ -34,15 +34,19 @@
 from distutils.dir_util import copy_tree


-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--inference_script", type=str, default="inference.py")
-    parser.add_argument("--model_archive", type=str, default="model.tar.gz")
-    args = parser.parse_args()
+def repack(inference_script, model_archive, dependencies=None, source_dir=None):
+    """Repack custom dependencies and code into an existing model TAR archive
+
+    Args:
+        inference_script (str): The path to the custom entry point.
+        model_archive (str): The name of the model TAR archive.
+        dependencies (str): A space-delimited string of paths to custom dependencies.
+        source_dir (str): The path to a custom source directory.
+    """

     # the data directory contains a model archive generated by a previous training job
     data_directory = "/opt/ml/input/data/training"
-    model_path = os.path.join(data_directory, args.model_archive)
+    model_path = os.path.join(data_directory, model_archive)

     # create a temporary directory
     with tempfile.TemporaryDirectory() as tmp:
@@ -51,17 +55,53 @@
         shutil.copy2(model_path, local_path)
         src_dir = os.path.join(tmp, "src")
         # create the "code" directory which will contain the inference script
-        os.makedirs(os.path.join(src_dir, "code"))
+        code_dir = os.path.join(src_dir, "code")
+        os.makedirs(code_dir)
         # extract the contents of the previous training job's model archive to the "src"
         # directory of this training job
         with tarfile.open(name=local_path, mode="r:gz") as tf:
             tf.extractall(path=src_dir)

-        # generate a path to the custom inference script
-        entry_point = os.path.join("/opt/ml/code", args.inference_script)
-        # copy the custom inference script to the "src" dir
-        shutil.copy2(entry_point, os.path.join(src_dir, "code", args.inference_script))
+        # copy the custom inference script to code/
+        entry_point = os.path.join("/opt/ml/code", inference_script)
+        shutil.copy2(entry_point, os.path.join(src_dir, "code", inference_script))
+
+        # copy source_dir to code/
+        if source_dir:
+            if os.path.exists(code_dir):
+                shutil.rmtree(code_dir)
+            shutil.copytree(source_dir, code_dir)
+
+        # copy any dependencies to code/lib/
+        if dependencies:
+            for dependency in dependencies.split(" "):
+                actual_dependency_path = os.path.join("/opt/ml/code", dependency)
+                lib_dir = os.path.join(code_dir, "lib")
+                if not os.path.exists(lib_dir):
+                    os.mkdir(lib_dir)
+                if os.path.isdir(actual_dependency_path):
+                    shutil.copytree(
+                        actual_dependency_path,
+                        os.path.join(lib_dir, os.path.basename(actual_dependency_path)),
+                    )
+                else:
+                    shutil.copy2(actual_dependency_path, lib_dir)

         # copy the "src" dir, which includes the previous training job's model and the
         # custom inference script, to the output of this training job
         copy_tree(src_dir, "/opt/ml/model")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--inference_script", type=str, default="inference.py")
+    parser.add_argument("--dependencies", type=str, default=None)
+    parser.add_argument("--source_dir", type=str, default=None)
+    parser.add_argument("--model_archive", type=str, default="model.tar.gz")
+    args, extra = parser.parse_known_args()
+    repack(
+        inference_script=args.inference_script,
+        dependencies=args.dependencies,
+        source_dir=args.source_dir,
+        model_archive=args.model_archive,
+    )
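For orientation, the repack step in _utils.py (below) passes these values to this script as SageMaker hyperparameters, which arrive as command-line flags and are forwarded to repack(). A minimal sketch of the equivalent direct call; the file and directory names (inference.py, requirements.txt, my_lib, my_src) are illustrative placeholders, not part of this commit:

from sagemaker.workflow import _repack_model

# Assumes SageMaker has already staged the previous job's model archive under
# /opt/ml/input/data/training and the user's code under /opt/ml/code.
_repack_model.repack(
    inference_script="inference.py",         # copied into code/ inside the archive
    model_archive="model.tar.gz",            # name of the staged model archive
    dependencies="requirements.txt my_lib",  # space-delimited; copied into code/lib/
    source_dir="my_src",                     # if set, code/ is replaced by a copy of this dir
)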

src/sagemaker/workflow/_utils.py

Lines changed: 7 additions & 0 deletions
@@ -145,6 +145,11 @@ def __init__(
         self._source_dir = source_dir
         self._dependencies = dependencies

+        # convert dependencies array into space-delimited string
+        dependencies_hyperparameter = None
+        if self._dependencies:
+            dependencies_hyperparameter = " ".join(self._dependencies)
+
         # the real estimator and inputs
         repacker = SKLearn(
             framework_version=FRAMEWORK_VERSION,
@@ -157,6 +162,8 @@ def __init__(
             hyperparameters={
                 "inference_script": self._entry_point_basename,
                 "model_archive": self._model_archive,
+                "dependencies": dependencies_hyperparameter,
+                "source_dir": self._source_dir,
             },
             subnets=subnets,
             security_group_ids=security_group_ids,
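Since the dependencies list has to travel as a single hyperparameter string, the join above and the split(" ") in _repack_model.py are mirror images. A small sketch with hypothetical paths (requirements.txt and my_package are illustrative only):

# Mirrors the conversion in the __init__ above.
dependencies = ["requirements.txt", "my_package"]
dependencies_hyperparameter = " ".join(dependencies) if dependencies else None
print(dependencies_hyperparameter)  # requirements.txt my_package

# repack() later recovers the individual entries with dependencies.split(" "),
# so a dependency path that itself contains a space would not round-trip cleanly.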
New unit test file for the repack script (file path not shown in this view)

Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+from sagemaker.workflow import _repack_model
+
+from pathlib import Path
+import shutil
+import tarfile
+import os
+import pytest
+import time
+
+
+@pytest.mark.skip(
+    reason="""This test operates on the root file system
+    and will likely fail due to permission errors.
+    Temporarily remove this skip decorator and run
+    the test after making changes to _repack_model.py"""
+)
+def test_repack_entry_point_only(tmp):
+    model_name = "xg-boost-model"
+    fake_model_path = os.path.join(tmp, model_name)
+
+    # create a fake model
+    open(fake_model_path, "w")
+
+    # create model.tar.gz
+    model_tar_name = "model-%s.tar.gz" % time.time()
+    model_tar_location = os.path.join(tmp, model_tar_name)
+    with tarfile.open(model_tar_location, mode="w:gz") as t:
+        t.add(fake_model_path, arcname=model_name)
+
+    # move model.tar.gz to /opt/ml/input/data/training
+    Path("/opt/ml/input/data/training").mkdir(parents=True, exist_ok=True)
+    shutil.move(model_tar_location, os.path.join("/opt/ml/input/data/training", model_tar_name))
+
+    # create files that will be added to model.tar.gz
+    create_file_tree(
+        "/opt/ml/code",
+        [
+            "inference.py",
+        ],
+    )
+
+    # repack
+    _repack_model.repack(inference_script="inference.py", model_archive=model_tar_name)
+
+    # /opt/ml/model should now have the original model and the inference script
+    assert os.path.exists(os.path.join("/opt/ml/model", model_name))
+    assert os.path.exists(os.path.join("/opt/ml/model/code", "inference.py"))
+
+
+@pytest.mark.skip(
+    reason="""This test operates on the root file system
+    and will likely fail due to permission errors.
+    Temporarily remove this skip decorator and run
+    the test after making changes to _repack_model.py"""
+)
+def test_repack_with_dependencies(tmp):
+    model_name = "xg-boost-model"
+    fake_model_path = os.path.join(tmp, model_name)
+
+    # create a fake model
+    open(fake_model_path, "w")
+
+    # create model.tar.gz
+    model_tar_name = "model-%s.tar.gz" % time.time()
+    model_tar_location = os.path.join(tmp, model_tar_name)
+    with tarfile.open(model_tar_location, mode="w:gz") as t:
+        t.add(fake_model_path, arcname=model_name)
+
+    # move model.tar.gz to /opt/ml/input/data/training
+    Path("/opt/ml/input/data/training").mkdir(parents=True, exist_ok=True)
+    shutil.move(model_tar_location, os.path.join("/opt/ml/input/data/training", model_tar_name))
+
+    # create files that will be added to model.tar.gz
+    create_file_tree(
+        "/opt/ml/code",
+        ["inference.py", "dependencies/a", "bb", "dependencies/some/dir/b"],
+    )
+
+    # repack
+    _repack_model.repack(
+        inference_script="inference.py",
+        model_archive=model_tar_name,
+        dependencies=["dependencies/a", "bb", "dependencies/some/dir"],
+    )
+
+    # /opt/ml/model should now have the original model and the inference script
+    assert os.path.exists(os.path.join("/opt/ml/model", model_name))
+    assert os.path.exists(os.path.join("/opt/ml/model/code", "inference.py"))
+    assert os.path.exists(os.path.join("/opt/ml/model/code/lib", "a"))
+    assert os.path.exists(os.path.join("/opt/ml/model/code/lib", "bb"))
+    assert os.path.exists(os.path.join("/opt/ml/model/code/lib/dir", "b"))
+
+
+@pytest.mark.skip(
+    reason="""This test operates on the root file system
+    and will likely fail due to permission errors.
+    Temporarily remove this skip decorator and run
+    the test after making changes to _repack_model.py"""
+)
+def test_repack_with_source_dir_and_dependencies(tmp):
+    model_name = "xg-boost-model"
+    fake_model_path = os.path.join(tmp, model_name)
+
+    # create a fake model
+    open(fake_model_path, "w")
+
+    # create model.tar.gz
+    model_tar_name = "model-%s.tar.gz" % time.time()
+    model_tar_location = os.path.join(tmp, model_tar_name)
+    with tarfile.open(model_tar_location, mode="w:gz") as t:
+        t.add(fake_model_path, arcname=model_name)
+
+    # move model.tar.gz to /opt/ml/input/data/training
+    Path("/opt/ml/input/data/training").mkdir(parents=True, exist_ok=True)
+    shutil.move(model_tar_location, os.path.join("/opt/ml/input/data/training", model_tar_name))
+
+    # create files that will be added to model.tar.gz
+    create_file_tree(
+        "/opt/ml/code",
+        [
+            "inference.py",
+            "dependencies/a",
+            "bb",
+            "dependencies/some/dir/b",
+            "sourcedir/foo.py",
+            "sourcedir/some/dir/a",
+        ],
+    )
+
+    # repack
+    _repack_model.repack(
+        inference_script="inference.py",
+        model_archive=model_tar_name,
+        dependencies=["dependencies/a", "bb", "dependencies/some/dir"],
+        source_dir="sourcedir",
+    )
+
+    # /opt/ml/model should now have the original model and the inference script
+    assert os.path.exists(os.path.join("/opt/ml/model", model_name))
+    assert os.path.exists(os.path.join("/opt/ml/model/code", "inference.py"))
+    assert os.path.exists(os.path.join("/opt/ml/model/code/lib", "a"))
+    assert os.path.exists(os.path.join("/opt/ml/model/code/lib", "bb"))
+    assert os.path.exists(os.path.join("/opt/ml/model/code/lib/dir", "b"))
+    assert os.path.exists(os.path.join("/opt/ml/model/code/", "foo.py"))
+    assert os.path.exists(os.path.join("/opt/ml/model/code/some/dir", "a"))
+
+
+def create_file_tree(root, tree):
+    for file in tree:
+        try:
+            os.makedirs(os.path.join(root, os.path.dirname(file)))
+        except:  # noqa: E722 Using bare except because p2/3 incompatibility issues.
+            pass
+        with open(os.path.join(root, file), "a") as f:
+            f.write(file)
+
+
+@pytest.fixture()
+def tmp(tmpdir):
+    yield str(tmpdir)
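One detail the dependency tests pin down: a directory dependency is copied into lib/ under its basename, which is why passing "dependencies/some/dir" leads to an assertion on code/lib/dir/b rather than code/lib/some/dir/b. A tiny sketch of that mapping, reusing the test's own example path:

import os

dep = "dependencies/some/dir"                        # directory dependency from the test above
target = os.path.join("lib", os.path.basename(dep))  # how repack() names the copied directory
print(target)  # lib/dir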

tests/unit/sagemaker/workflow/test_step_collections.py

Lines changed: 8 additions & 0 deletions
@@ -366,6 +366,7 @@ def test_register_model_sip(estimator, model_metrics):

 def test_register_model_with_model_repack_with_estimator(estimator, model_metrics):
     model_data = f"s3://{BUCKET}/model.tar.gz"
+    dummy_requirements = f"{DATA_DIR}/dummy_requirements.txt"
     register_model = RegisterModel(
         name="RegisterModelStep",
         estimator=estimator,
@@ -379,6 +380,7 @@ def test_register_model_with_model_repack_with_estimator(estimator, model_metric
         approval_status="Approved",
         description="description",
         entry_point=f"{DATA_DIR}/dummy_script.py",
+        dependencies=[dummy_requirements],
         depends_on=["TestStep"],
         tags=[{"Key": "myKey", "Value": "myValue"}],
     )
@@ -405,6 +407,7 @@ def test_register_model_with_model_repack_with_estimator(estimator, model_metric
             },
             "HyperParameters": {
                 "inference_script": '"dummy_script.py"',
+                "dependencies": f'"{dummy_requirements}"',
                 "model_archive": '"model.tar.gz"',
                 "sagemaker_submit_directory": '"s3://{}/{}/source/sourcedir.tar.gz"'.format(
                     BUCKET, repacker_job_name.replace('"', "")
@@ -413,6 +416,7 @@ def test_register_model_with_model_repack_with_estimator(estimator, model_metric
                 "sagemaker_container_log_level": "20",
                 "sagemaker_job_name": repacker_job_name,
                 "sagemaker_region": f'"{REGION}"',
+                "source_dir": "null",
             },
             "InputDataConfig": [
                 {
@@ -528,6 +532,8 @@ def test_register_model_with_model_repack_with_model(model, model_metrics):
                 "sagemaker_container_log_level": "20",
                 "sagemaker_job_name": repacker_job_name,
                 "sagemaker_region": f'"{REGION}"',
+                "dependencies": "null",
+                "source_dir": "null",
             },
             "InputDataConfig": [
                 {
@@ -631,6 +637,7 @@ def test_register_model_with_model_repack_with_pipeline_model(pipeline_model, mo
                 "S3OutputPath": f"s3://{BUCKET}/",
             },
             "HyperParameters": {
+                "dependencies": "null",
                 "inference_script": '"dummy_script.py"',
                 "model_archive": '"model.tar.gz"',
                 "sagemaker_submit_directory": '"s3://{}/{}/source/sourcedir.tar.gz"'.format(
@@ -640,6 +647,7 @@ def test_register_model_with_model_repack_with_pipeline_model(pipeline_model, mo
                 "sagemaker_container_log_level": "20",
                 "sagemaker_job_name": repacker_job_name,
                 "sagemaker_region": f'"{REGION}"',
+                "source_dir": "null",
             },
             "InputDataConfig": [
                 {
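A note on the expected values: the SDK JSON-serializes framework hyperparameters, so a string value shows up wrapped in an extra pair of quotes and an unset value becomes the literal "null", which is what the updated expectations encode. A quick illustration; the requirements path below is a made-up stand-in for f"{DATA_DIR}/dummy_requirements.txt":

import json

# Strings keep their embedded quotes after serialization; None becomes "null".
print(json.dumps("tests/data/dummy_requirements.txt"))  # prints "tests/data/dummy_requirements.txt" (quotes included)
print(json.dumps(None))                                 # prints null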
