Skip to content

Commit bf8bcd8

Browse files
committed
feature: support displayName and description for pipeline steps
1 parent 80d5fb7 commit bf8bcd8

File tree

5 files changed

+206
-2
lines changed

5 files changed

+206
-2
lines changed

src/sagemaker/workflow/lambda_step.py

+3 -1
Original file line number | Diff line number | Diff line change
@@ -105,7 +105,9 @@ def __init__(
105 105
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.LambdaStep`
106 106
depends on
107 107
"""
108-
super(LambdaStep, self).__init__(name,display_name, description, StepTypeEnum.LAMBDA, depends_on)
108+
super(LambdaStep, self).__init__(
109+
name, display_name, description, StepTypeEnum.LAMBDA, depends_on
110+
)
109 111
self.lambda_func = lambda_func
110 112
self.outputs = outputs if outputs is not None else []
111 113
self.cache_config = cache_config
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,67 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Repack model script for training jobs to inject entry points"""
14+
from __future__ import absolute_import
15+
16+
import argparse
17+
import os
18+
import shutil
19+
import tarfile
20+
import tempfile
21+
22+
# Repack Model
23+
# The following script is run via a training job which takes an existing model and a custom
24+
# entry point script as arguments. The script creates a new model archive with the custom
25+
# entry point in the "code" directory along with the existing model. Subsequently, when the model
26+
# is unpacked for inference, the custom entry point will be used.
27+
# Reference: https://docs.aws.amazon.com/sagemaker/latest/dg/amazon-sagemaker-toolkits.html
28+
29+
# distutils.dir_util.copy_tree works way better than the half-baked
30+
# shutil.copytree which bombs on previously existing target dirs...
31+
# alas ... https://bugs.python.org/issue10948
32+
# we'll go ahead and use the copy_tree function anyways because this
33+
# repacking is some short-lived hackery, right??
34+
from distutils.dir_util import copy_tree
35+
36+
37+
if __name__ == "__main__":
38+
parser = argparse.ArgumentParser()
39+
parser.add_argument("--inference_script", type=str, default="inference.py")
40+
parser.add_argument("--model_archive", type=str, default="model.tar.gz")
41+
args = parser.parse_args()
42+
43+
# the data directory contains a model archive generated by a previous training job
44+
data_directory = "/opt/ml/input/data/training"
45+
model_path = os.path.join(data_directory, args.model_archive)
46+
47+
# create a temporary directory
48+
with tempfile.TemporaryDirectory() as tmp:
49+
local_path = os.path.join(tmp, "local.tar.gz")
50+
# copy the previous training job's model archive to the temporary directory
51+
shutil.copy2(model_path, local_path)
52+
src_dir = os.path.join(tmp, "src")
53+
# create the "code" directory which will contain the inference script
54+
os.makedirs(os.path.join(src_dir, "code"))
55+
# extract the contents of the previous training job's model archive to the "src"
56+
# directory of this training job
57+
with tarfile.open(name=local_path, mode="r:gz") as tf:
58+
tf.extractall(path=src_dir)
59+
60+
# generate a path to the custom inference script
61+
entry_point = os.path.join("/opt/ml/code", args.inference_script)
62+
# copy the custom inference script to the "src" dir
63+
shutil.copy2(entry_point, os.path.join(src_dir, "code", args.inference_script))
64+
65+
# copy the "src" dir, which includes the previous training job's model and the
66+
# custom inference script, to the output of this training job
67+
copy_tree(src_dir, "/opt/ml/model")
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,67 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Repack model script for training jobs to inject entry points"""
14+
from __future__ import absolute_import
15+
16+
import argparse
17+
import os
18+
import shutil
19+
import tarfile
20+
import tempfile
21+
22+
# Repack Model
23+
# The following script is run via a training job which takes an existing model and a custom
24+
# entry point script as arguments. The script creates a new model archive with the custom
25+
# entry point in the "code" directory along with the existing model. Subsequently, when the model
26+
# is unpacked for inference, the custom entry point will be used.
27+
# Reference: https://docs.aws.amazon.com/sagemaker/latest/dg/amazon-sagemaker-toolkits.html
28+
29+
# distutils.dir_util.copy_tree works way better than the half-baked
30+
# shutil.copytree which bombs on previously existing target dirs...
31+
# alas ... https://bugs.python.org/issue10948
32+
# we'll go ahead and use the copy_tree function anyways because this
33+
# repacking is some short-lived hackery, right??
34+
from distutils.dir_util import copy_tree
35+
36+
37+
if __name__ == "__main__":
38+
parser = argparse.ArgumentParser()
39+
parser.add_argument("--inference_script", type=str, default="inference.py")
40+
parser.add_argument("--model_archive", type=str, default="model.tar.gz")
41+
args = parser.parse_args()
42+
43+
# the data directory contains a model archive generated by a previous training job
44+
data_directory = "/opt/ml/input/data/training"
45+
model_path = os.path.join(data_directory, args.model_archive)
46+
47+
# create a temporary directory
48+
with tempfile.TemporaryDirectory() as tmp:
49+
local_path = os.path.join(tmp, "local.tar.gz")
50+
# copy the previous training job's model archive to the temporary directory
51+
shutil.copy2(model_path, local_path)
52+
src_dir = os.path.join(tmp, "src")
53+
# create the "code" directory which will contain the inference script
54+
os.makedirs(os.path.join(src_dir, "code"))
55+
# extract the contents of the previous training job's model archive to the "src"
56+
# directory of this training job
57+
with tarfile.open(name=local_path, mode="r:gz") as tf:
58+
tf.extractall(path=src_dir)
59+
60+
# generate a path to the custom inference script
61+
entry_point = os.path.join("/opt/ml/code", args.inference_script)
62+
# copy the custom inference script to the "src" dir
63+
shutil.copy2(entry_point, os.path.join(src_dir, "code", args.inference_script))
64+
65+
# copy the "src" dir, which includes the previous training job's model and the
66+
# custom inference script, to the output of this training job
67+
copy_tree(src_dir, "/opt/ml/model")
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,67 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Repack model script for training jobs to inject entry points"""
14+
from __future__ import absolute_import
15+
16+
import argparse
17+
import os
18+
import shutil
19+
import tarfile
20+
import tempfile
21+
22+
# Repack Model
23+
# The following script is run via a training job which takes an existing model and a custom
24+
# entry point script as arguments. The script creates a new model archive with the custom
25+
# entry point in the "code" directory along with the existing model. Subsequently, when the model
26+
# is unpacked for inference, the custom entry point will be used.
27+
# Reference: https://docs.aws.amazon.com/sagemaker/latest/dg/amazon-sagemaker-toolkits.html
28+
29+
# distutils.dir_util.copy_tree works way better than the half-baked
30+
# shutil.copytree which bombs on previously existing target dirs...
31+
# alas ... https://bugs.python.org/issue10948
32+
# we'll go ahead and use the copy_tree function anyways because this
33+
# repacking is some short-lived hackery, right??
34+
from distutils.dir_util import copy_tree
35+
36+
37+
if __name__ == "__main__":
38+
parser = argparse.ArgumentParser()
39+
parser.add_argument("--inference_script", type=str, default="inference.py")
40+
parser.add_argument("--model_archive", type=str, default="model.tar.gz")
41+
args = parser.parse_args()
42+
43+
# the data directory contains a model archive generated by a previous training job
44+
data_directory = "/opt/ml/input/data/training"
45+
model_path = os.path.join(data_directory, args.model_archive)
46+
47+
# create a temporary directory
48+
with tempfile.TemporaryDirectory() as tmp:
49+
local_path = os.path.join(tmp, "local.tar.gz")
50+
# copy the previous training job's model archive to the temporary directory
51+
shutil.copy2(model_path, local_path)
52+
src_dir = os.path.join(tmp, "src")
53+
# create the "code" directory which will contain the inference script
54+
os.makedirs(os.path.join(src_dir, "code"))
55+
# extract the contents of the previous training job's model archive to the "src"
56+
# directory of this training job
57+
with tarfile.open(name=local_path, mode="r:gz") as tf:
58+
tf.extractall(path=src_dir)
59+
60+
# generate a path to the custom inference script
61+
entry_point = os.path.join("/opt/ml/code", args.inference_script)
62+
# copy the custom inference script to the "src" dir
63+
shutil.copy2(entry_point, os.path.join(src_dir, "code", args.inference_script))
64+
65+
# copy the "src" dir, which includes the previous training job's model and the
66+
# custom inference script, to the output of this training job
67+
copy_tree(src_dir, "/opt/ml/model")

tests/unit/sagemaker/workflow/test_lambda_step.py

+2 -1
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,8 @@ def test_lambda_step(sagemaker_session):
48 48
function_arn="arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda",
49 49
session=sagemaker_session,
50 50
),
51-
display_name="MyLambdaStep", description="MyLambdaStepDescription",
51+
display_name="MyLambdaStep",
52+
description="MyLambdaStepDescription",
52 53
inputs={"arg1": "foo", "arg2": 5, "arg3": param},
53 54
outputs=[outputParam1, outputParam2],
54 55
)

0 commit comments

Comments
 (0)