Skip to content

Commit 1382f15

Browse files
Merge pull request #10 from aws/master
2 parents 437c39c + cd22a6e commit 1382f15

12 files changed

+1069
-32
lines changed

CHANGELOG.md

+16
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,21 @@
11
# Changelog
22

3+
## v2.63.0 (2021-10-13)
4+
5+
### Features
6+
7+
* support configurable retry for pipeline steps
8+
9+
## v2.62.0 (2021-10-12)
10+
11+
### Features
12+
13+
* Hugging Face Transformers 4.10 for Pt1.8/TF2.4 & Transformers 4.11 for PT1.9&TF2.5
14+
15+
### Bug Fixes and Other Changes
16+
17+
* repack_model script used in pipelines to support source_dir and dependencies
18+
319
## v2.61.0 (2021-10-11)
420

521
### Features

VERSION

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.61.1.dev0
1+
2.63.1.dev0

src/sagemaker/workflow/_repack_model.py

+51-11
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,19 @@
3434
from distutils.dir_util import copy_tree
3535

3636

37-
if __name__ == "__main__":
38-
parser = argparse.ArgumentParser()
39-
parser.add_argument("--inference_script", type=str, default="inference.py")
40-
parser.add_argument("--model_archive", type=str, default="model.tar.gz")
41-
args = parser.parse_args()
37+
def repack(inference_script, model_archive, dependencies=None, source_dir=None):
38+
"""Repack custom dependencies and code into an existing model TAR archive
39+
40+
Args:
41+
inference_script (str): The path to the custom entry point.
42+
model_archive (str): The name of the model TAR archive.
43+
dependencies (str): A space-delimited string of paths to custom dependencies.
44+
source_dir (str): The path to a custom source directory.
45+
"""
4246

4347
# the data directory contains a model archive generated by a previous training job
4448
data_directory = "/opt/ml/input/data/training"
45-
model_path = os.path.join(data_directory, args.model_archive)
49+
model_path = os.path.join(data_directory, model_archive)
4650

4751
# create a temporary directory
4852
with tempfile.TemporaryDirectory() as tmp:
@@ -51,17 +55,53 @@
5155
shutil.copy2(model_path, local_path)
5256
src_dir = os.path.join(tmp, "src")
5357
# create the "code" directory which will contain the inference script
54-
os.makedirs(os.path.join(src_dir, "code"))
58+
code_dir = os.path.join(src_dir, "code")
59+
os.makedirs(code_dir)
5560
# extract the contents of the previous training job's model archive to the "src"
5661
# directory of this training job
5762
with tarfile.open(name=local_path, mode="r:gz") as tf:
5863
tf.extractall(path=src_dir)
5964

60-
# generate a path to the custom inference script
61-
entry_point = os.path.join("/opt/ml/code", args.inference_script)
62-
# copy the custom inference script to the "src" dir
63-
shutil.copy2(entry_point, os.path.join(src_dir, "code", args.inference_script))
65+
# copy the custom inference script to code/
66+
entry_point = os.path.join("/opt/ml/code", inference_script)
67+
shutil.copy2(entry_point, os.path.join(src_dir, "code", inference_script))
68+
69+
# copy source_dir to code/
70+
if source_dir:
71+
if os.path.exists(code_dir):
72+
shutil.rmtree(code_dir)
73+
shutil.copytree(source_dir, code_dir)
74+
75+
# copy any dependencies to code/lib/
76+
if dependencies:
77+
for dependency in dependencies.split(" "):
78+
actual_dependency_path = os.path.join("/opt/ml/code", dependency)
79+
lib_dir = os.path.join(code_dir, "lib")
80+
if not os.path.exists(lib_dir):
81+
os.mkdir(lib_dir)
82+
if os.path.isdir(actual_dependency_path):
83+
shutil.copytree(
84+
actual_dependency_path,
85+
os.path.join(lib_dir, os.path.basename(actual_dependency_path)),
86+
)
87+
else:
88+
shutil.copy2(actual_dependency_path, lib_dir)
6489

6590
# copy the "src" dir, which includes the previous training job's model and the
6691
# custom inference script, to the output of this training job
6792
copy_tree(src_dir, "/opt/ml/model")
93+
94+
95+
if __name__ == "__main__":
96+
parser = argparse.ArgumentParser()
97+
parser.add_argument("--inference_script", type=str, default="inference.py")
98+
parser.add_argument("--dependencies", type=str, default=None)
99+
parser.add_argument("--source_dir", type=str, default=None)
100+
parser.add_argument("--model_archive", type=str, default="model.tar.gz")
101+
args, extra = parser.parse_known_args()
102+
repack(
103+
inference_script=args.inference_script,
104+
dependencies=args.dependencies,
105+
source_dir=args.source_dir,
106+
model_archive=args.model_archive,
107+
)

src/sagemaker/workflow/_utils.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,14 @@
2828
from sagemaker.sklearn.estimator import SKLearn
2929
from sagemaker.workflow.entities import RequestType
3030
from sagemaker.workflow.properties import Properties
31-
from sagemaker.session import get_create_model_package_request
32-
from sagemaker.session import get_model_package_args
31+
from sagemaker.session import get_create_model_package_request, get_model_package_args
3332
from sagemaker.workflow.steps import (
3433
StepTypeEnum,
3534
TrainingStep,
3635
Step,
36+
ConfigurableRetryStep,
3737
)
38+
from sagemaker.workflow.retry import RetryPolicy
3839

3940
FRAMEWORK_VERSION = "0.23-1"
4041
INSTANCE_TYPE = "ml.m5.large"
@@ -60,6 +61,7 @@ def __init__(
6061
source_dir: str = None,
6162
dependencies: List = None,
6263
depends_on: Union[List[str], List[Step]] = None,
64+
retry_policies: List[RetryPolicy] = None,
6365
subnets=None,
6466
security_group_ids=None,
6567
**kwargs,
@@ -126,6 +128,7 @@ def __init__(
126128
This is not supported with "local code" in Local Mode.
127129
depends_on (List[str] or List[Step]): A list of step names or instances
128130
this step depends on
131+
retry_policies (List[RetryPolicy]): The list of retry policies for the current step
129132
subnets (list[str]): List of subnet ids. If not specified, the re-packing
130133
job will be created without VPC config.
131134
security_group_ids (list[str]): List of security group ids. If not
@@ -145,6 +148,11 @@ def __init__(
145148
self._source_dir = source_dir
146149
self._dependencies = dependencies
147150

151+
# convert dependencies array into space-delimited string
152+
dependencies_hyperparameter = None
153+
if self._dependencies:
154+
dependencies_hyperparameter = " ".join(self._dependencies)
155+
148156
# the real estimator and inputs
149157
repacker = SKLearn(
150158
framework_version=FRAMEWORK_VERSION,
@@ -157,6 +165,8 @@ def __init__(
157165
hyperparameters={
158166
"inference_script": self._entry_point_basename,
159167
"model_archive": self._model_archive,
168+
"dependencies": dependencies_hyperparameter,
169+
"source_dir": self._source_dir,
160170
},
161171
subnets=subnets,
162172
security_group_ids=security_group_ids,
@@ -171,6 +181,7 @@ def __init__(
171181
display_name=display_name,
172182
description=description,
173183
depends_on=depends_on,
184+
retry_policies=retry_policies,
174185
estimator=repacker,
175186
inputs=inputs,
176187
)
@@ -252,7 +263,7 @@ def properties(self):
252263
return self._properties
253264

254265

255-
class _RegisterModelStep(Step):
266+
class _RegisterModelStep(ConfigurableRetryStep):
256267
"""Register model step in workflow that creates a model package.
257268
258269
Attributes:
@@ -295,6 +306,7 @@ def __init__(
295306
display_name: str = None,
296307
description=None,
297308
depends_on: Union[List[str], List[Step]] = None,
309+
retry_policies: List[RetryPolicy] = None,
298310
tags=None,
299311
container_def_list=None,
300312
**kwargs,
@@ -332,10 +344,11 @@ def __init__(
332344
description (str): Model Package description (default: None).
333345
depends_on (List[str] or List[Step]): A list of step names or instances
334346
this step depends on
347+
retry_policies (List[RetryPolicy]): The list of retry policies for the current step
335348
**kwargs: additional arguments to `create_model`.
336349
"""
337350
super(_RegisterModelStep, self).__init__(
338-
name, display_name, description, StepTypeEnum.REGISTER_MODEL, depends_on
351+
name, StepTypeEnum.REGISTER_MODEL, display_name, description, depends_on, retry_policies
339352
)
340353
self.estimator = estimator
341354
self.model_data = model_data

0 commit comments

Comments
 (0)