Skip to content

Commit 87be0fc

Browse files
committed
fix black-format
1 parent e1b1d8d commit 87be0fc

File tree

3 files changed

+127
-9
lines changed

3 files changed

+127
-9
lines changed

src/sagemaker/fw_utils.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -808,10 +808,10 @@ def validate_pytorch_distribution(
808808
# We need to validate only for PyTorch framework
809809
return
810810
if "pytorchddp" in distribution:
811-
pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False)
812-
if not pytorch_ddp_enabled:
813-
# Distribution strategy other than pytorchddp is selected
814-
return
811+
pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False)
812+
if not pytorch_ddp_enabled:
813+
# Distribution strategy other than pytorchddp is selected
814+
return
815815

816816
err_msg = ""
817817
if not image_uri:

tests/data/_repack_model.py

+110
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Repack model script for training jobs to inject entry points"""
14+
from __future__ import absolute_import
15+
16+
import argparse
17+
import os
18+
import shutil
19+
import tarfile
20+
import tempfile
21+
22+
# Repack Model
23+
# The following script is run via a training job which takes an existing model and a custom
24+
# entry point script as arguments. The script creates a new model archive with the custom
25+
# entry point in the "code" directory along with the existing model. Subsequently, when the model
26+
# is unpacked for inference, the custom entry point will be used.
27+
# Reference: https://docs.aws.amazon.com/sagemaker/latest/dg/amazon-sagemaker-toolkits.html
28+
29+
# distutils.dir_util.copy_tree works way better than the half-baked
30+
# shutil.copytree which bombs on previously existing target dirs...
31+
# alas ... https://bugs.python.org/issue10948
32+
# we'll go ahead and use the copy_tree function anyways because this
33+
# repacking is some short-lived hackery, right??
34+
from distutils.dir_util import copy_tree
35+
36+
37+
def repack(inference_script, model_archive, dependencies=None, source_dir=None): # pragma: no cover
38+
"""Repack custom dependencies and code into an existing model TAR archive
39+
40+
Args:
41+
inference_script (str): The path to the custom entry point.
42+
model_archive (str): The name or path (e.g. s3 uri) of the model TAR archive.
43+
dependencies (str): A space-delimited string of paths to custom dependencies.
44+
source_dir (str): The path to a custom source directory.
45+
"""
46+
47+
# the data directory contains a model archive generated by a previous training job
48+
data_directory = "/opt/ml/input/data/training"
49+
model_path = os.path.join(data_directory, model_archive.split("/")[-1])
50+
51+
# create a temporary directory
52+
with tempfile.TemporaryDirectory() as tmp:
53+
local_path = os.path.join(tmp, "local.tar.gz")
54+
# copy the previous training job's model archive to the temporary directory
55+
shutil.copy2(model_path, local_path)
56+
src_dir = os.path.join(tmp, "src")
57+
# create the "code" directory which will contain the inference script
58+
code_dir = os.path.join(src_dir, "code")
59+
os.makedirs(code_dir)
60+
# extract the contents of the previous training job's model archive to the "src"
61+
# directory of this training job
62+
with tarfile.open(name=local_path, mode="r:gz") as tf:
63+
tf.extractall(path=src_dir)
64+
65+
if source_dir:
66+
# copy /opt/ml/code to code/
67+
if os.path.exists(code_dir):
68+
shutil.rmtree(code_dir)
69+
shutil.copytree("/opt/ml/code", code_dir)
70+
else:
71+
# copy the custom inference script to code/
72+
entry_point = os.path.join("/opt/ml/code", inference_script)
73+
shutil.copy2(entry_point, os.path.join(code_dir, inference_script))
74+
75+
# copy any dependencies to code/lib/
76+
if dependencies:
77+
for dependency in dependencies.split(" "):
78+
actual_dependency_path = os.path.join("/opt/ml/code", dependency)
79+
lib_dir = os.path.join(code_dir, "lib")
80+
if not os.path.exists(lib_dir):
81+
os.mkdir(lib_dir)
82+
if os.path.isfile(actual_dependency_path):
83+
shutil.copy2(actual_dependency_path, lib_dir)
84+
else:
85+
if os.path.exists(lib_dir):
86+
shutil.rmtree(lib_dir)
87+
# a directory is in the dependencies. we have to copy
88+
# all of /opt/ml/code into the lib dir because the original directory
89+
# was flattened by the SDK training job upload..
90+
shutil.copytree("/opt/ml/code", lib_dir)
91+
break
92+
93+
# copy the "src" dir, which includes the previous training job's model and the
94+
# custom inference script, to the output of this training job
95+
copy_tree(src_dir, "/opt/ml/model")
96+
97+
98+
if __name__ == "__main__": # pragma: no cover
99+
parser = argparse.ArgumentParser()
100+
parser.add_argument("--inference_script", type=str, default="inference.py")
101+
parser.add_argument("--dependencies", type=str, default=None)
102+
parser.add_argument("--source_dir", type=str, default=None)
103+
parser.add_argument("--model_archive", type=str, default="model.tar.gz")
104+
args, extra = parser.parse_known_args()
105+
repack(
106+
inference_script=args.inference_script,
107+
dependencies=args.dependencies,
108+
source_dir=args.source_dir,
109+
model_archive=args.model_archive,
110+
)

tests/unit/test_fw_utils.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -858,7 +858,7 @@ def test_validate_pytorchddp_not_raises():
858858
py_version="py3",
859859
image_uri="custom-container",
860860
)
861-
# Case 2: Framework is PyTorch, but distribution is not PyTorchDDP
861+
# Case 2: Framework is PyTorch, but distribution is not PyTorchDDP
862862
pytorchddp_disabled = {"pytorchddp": {"enabled": False}}
863863
fw_utils.validate_pytorch_distribution(
864864
distribution=pytorchddp_disabled,
@@ -867,9 +867,17 @@ def test_validate_pytorchddp_not_raises():
867867
py_version="py3",
868868
image_uri="custom-container",
869869
)
870-
# Case 3: Framework is PyTorch, Distribution is PyTorchDDP enabled, supported framework and py versions
870+
# Case 3: Framework is PyTorch, Distribution is PyTorchDDP enabled, supported framework and py versions
871871
pytorchddp_enabled = {"pytorchddp": {"enabled": True}}
872-
pytorchddp_supported_fw_versions = ["1.10","1.10.0", "1.10.2","1.11","1.11.0","1.12","1.12.0"]
872+
pytorchddp_supported_fw_versions = [
873+
"1.10",
874+
"1.10.0",
875+
"1.10.2",
876+
"1.11",
877+
"1.11.0",
878+
"1.12",
879+
"1.12.0",
880+
]
873881
for framework_version in pytorchddp_supported_fw_versions:
874882
fw_utils.validate_pytorch_distribution(
875883
distribution=pytorchddp_enabled,
@@ -892,12 +900,12 @@ def test_validate_pytorchddp_raises():
892900
image_uri=None,
893901
)
894902

895-
# Case 2: Unsupported Py version
903+
# Case 2: Unsupported Py version
896904
with pytest.raises(ValueError):
897905
fw_utils.validate_pytorch_distribution(
898906
distribution=pytorchddp_enabled,
899907
framework_name="pytorch",
900908
framework_version="1.10",
901909
py_version="py2",
902910
image_uri=None,
903-
)
911+
)

0 commit comments

Comments (0)