Commit 44d29b9

Allow users to customize trial component display names for pipeline launched jobs
1 parent a5464a2 commit 44d29b9

11 files changed, +363 -27 lines changed

doc/amazon_sagemaker_model_building_pipeline.rst

Lines changed: 2 additions & 0 deletions
@@ -741,6 +741,8 @@ There are a number of properties for a pipeline execution that can only be resol
 - :class:`sagemaker.workflow.execution_variables.ExecutionVariables.PIPELINE_EXECUTION_ARN`: The execution ARN for an execution.
 - :class:`sagemaker.workflow.execution_variables.ExecutionVariables.PIPELINE_NAME`: The name of the pipeline.
 - :class:`sagemaker.workflow.execution_variables.ExecutionVariables.PIPELINE_ARN`: The ARN of the pipeline.
+- :class:`sagemaker.workflow.execution_variables.ExecutionVariables.TRAINING_JOB_NAME`: The name of the training job launched by the training step.
+- :class:`sagemaker.workflow.execution_variables.ExecutionVariables.PROCESSING_JOB_NAME`: The name of the processing job launched by the processing step.

 You can use these execution variables as you see fit. The following example uses the :code:`START_DATETIME` execution variable to construct a processing output path:
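
A rough illustration (not part of the diff; the bucket name is a placeholder) of how the two new variables can be used, in the same spirit as the START_DATETIME example referenced above:

from sagemaker.workflow.execution_variables import ExecutionVariables
from sagemaker.workflow.functions import Join

# Hypothetical output prefix that embeds the execution ID and the name of the
# processing job launched by the current processing step.
output_path = Join(
    on="/",
    values=[
        "s3://my-bucket/runs",
        ExecutionVariables.PIPELINE_EXECUTION_ID,
        ExecutionVariables.PROCESSING_JOB_NAME,
    ],
)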

doc/workflows/pipelines/sagemaker.workflow.pipelines.rst

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ Execution Variables
 .. autoclass:: sagemaker.workflow.execution_variables.ExecutionVariable

 .. autoclass:: sagemaker.workflow.execution_variables.ExecutionVariables
-    :members: START_DATETIME, CURRENT_DATETIME, PIPELINE_EXECUTION_ID, PIPELINE_EXECUTION_ARN, PIPELINE_NAME, PIPELINE_ARN
+    :members: START_DATETIME, CURRENT_DATETIME, PIPELINE_EXECUTION_ID, PIPELINE_EXECUTION_ARN, PIPELINE_NAME, PIPELINE_ARN, TRAINING_JOB_NAME, PROCESSING_JOB_NAME

 Functions
 ---------

src/sagemaker/estimator.py

Lines changed: 6 additions & 0 deletions
@@ -1000,6 +1000,12 @@ def fit(
                 * If both `ExperimentName` and `TrialName` are not supplied the trial component
                   will be unassociated.
                 * `TrialComponentDisplayName` is used for display in Studio.
+                * Both `ExperimentName` and `TrialName` will be ignored if the Estimator instance
+                  is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`.
+                  However, the value of `TrialComponentDisplayName` is honored for display in Studio.
+        Returns:
+            None or pipeline step arguments in case the Estimator instance is built with
+            :class:`~sagemaker.workflow.pipeline_context.PipelineSession`
         """
         self._prepare_for_training(job_name=job_name)
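
A minimal sketch of the behavior this docstring describes (not part of the commit; image URI, role, and S3 paths are placeholders): when the Estimator is built with a PipelineSession, ExperimentName and TrialName are dropped from the generated step arguments, and only TrialComponentDisplayName is carried through for display in Studio.

from sagemaker.estimator import Estimator
from sagemaker.workflow.pipeline_context import PipelineSession
from sagemaker.workflow.steps import TrainingStep

pipeline_session = PipelineSession()

estimator = Estimator(
    image_uri="my-training-image",  # placeholder
    role="arn:aws:iam::123456789012:role/MyRole",  # placeholder
    instance_count=1,
    instance_type="ml.m5.xlarge",
    sagemaker_session=pipeline_session,
)

# ExperimentName/TrialName are ignored for pipeline-launched jobs;
# only TrialComponentDisplayName survives into the step arguments.
step_args = estimator.fit(
    inputs="s3://my-bucket/train",  # placeholder
    experiment_config={
        "ExperimentName": "my-experiment",
        "TrialName": "my-trial",
        "TrialComponentDisplayName": "my-training-display-name",
    },
)

step = TrainingStep(name="MyTrainingStep", step_args=step_args)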

src/sagemaker/processing.py

Lines changed: 18 additions & 1 deletion
@@ -173,9 +173,14 @@ def run(
                 * If both `ExperimentName` and `TrialName` are not supplied the trial component
                   will be unassociated.
                 * `TrialComponentDisplayName` is used for display in Studio.
+                * Both `ExperimentName` and `TrialName` will be ignored if the Processor instance
+                  is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`.
+                  However, the value of `TrialComponentDisplayName` is honored for display in Studio.
             kms_key (str): The ARN of the KMS key that is used to encrypt the
                 user code file (default: None).
-
+        Returns:
+            None or pipeline step arguments in case the Processor instance is built with
+            :class:`~sagemaker.workflow.pipeline_context.PipelineSession`
         Raises:
             ValueError: if ``logs`` is True but ``wait`` is False.
         """
@@ -543,8 +548,14 @@ def run(
                 * If both `ExperimentName` and `TrialName` are not supplied the trial component
                   will be unassociated.
                 * `TrialComponentDisplayName` is used for display in Studio.
+                * Both `ExperimentName` and `TrialName` will be ignored if the Processor instance
+                  is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`.
+                  However, the value of `TrialComponentDisplayName` is honored for display in Studio.
             kms_key (str): The ARN of the KMS key that is used to encrypt the
                 user code file (default: None).
+        Returns:
+            None or pipeline step arguments in case the Processor instance is built with
+            :class:`~sagemaker.workflow.pipeline_context.PipelineSession`
         """
         normalized_inputs, normalized_outputs = self._normalize_args(
             job_name=job_name,
@@ -1601,8 +1612,14 @@ def run( # type: ignore[override]
                 * If both `ExperimentName` and `TrialName` are not supplied the trial component
                   will be unassociated.
                 * `TrialComponentDisplayName` is used for display in Studio.
+                * Both `ExperimentName` and `TrialName` will be ignored if the Processor instance
+                  is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`.
+                  However, the value of `TrialComponentDisplayName` is honored for display in Studio.
             kms_key (str): The ARN of the KMS key that is used to encrypt the
                 user code file (default: None).
+        Returns:
+            None or pipeline step arguments in case the Processor instance is built with
+            :class:`~sagemaker.workflow.pipeline_context.PipelineSession`
         """
         s3_runproc_sh, inputs, job_name = self._pack_and_upload_code(
             code, source_dir, dependencies, git_config, job_name, inputs
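
The processor counterpart, again only a hedged sketch with placeholder names:

from sagemaker.processing import Processor
from sagemaker.workflow.pipeline_context import PipelineSession
from sagemaker.workflow.steps import ProcessingStep

pipeline_session = PipelineSession()

processor = Processor(
    image_uri="my-processing-image",  # placeholder
    role="arn:aws:iam::123456789012:role/MyRole",  # placeholder
    instance_count=1,
    instance_type="ml.m5.xlarge",
    sagemaker_session=pipeline_session,
)

# Only TrialComponentDisplayName is kept in the rendered step arguments.
step_args = processor.run(
    inputs=None,  # a list of ProcessingInput objects would normally go here
    experiment_config={"TrialComponentDisplayName": "my-processing-display-name"},
)

step = ProcessingStep(name="MyProcessingStep", step_args=step_args)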

src/sagemaker/transformer.py

Lines changed: 8 additions & 0 deletions
@@ -186,6 +186,9 @@ def transform(
                 * If both `ExperimentName` and `TrialName` are not supplied the trial component
                   will be unassociated.
                 * `TrialComponentDisplayName` is used for display in Studio.
+                * Both `ExperimentName` and `TrialName` will be ignored if the Transformer instance
+                  is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`.
+                  However, the value of `TrialComponentDisplayName` is honored for display in Studio.
             model_client_config (dict[str, str]): Model configuration.
                 Dictionary contains two optional keys,
                 'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'.
@@ -194,6 +197,11 @@ def transform(
                 (default: ``True``).
             logs (bool): Whether to show the logs produced by the job.
                 Only meaningful when wait is ``True`` (default: ``True``).
+            kms_key (str): The ARN of the KMS key that is used to encrypt the
+                user code file (default: None).
+        Returns:
+            None or pipeline step arguments in case the Transformer instance is built with
+            :class:`~sagemaker.workflow.pipeline_context.PipelineSession`
         """
         local_mode = self.sagemaker_session.local_mode
         if not local_mode and not is_pipeline_variable(data) and not data.startswith("s3://"):
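
And the transform counterpart as a sketch (model name, bucket, and paths are placeholders):

from sagemaker.transformer import Transformer
from sagemaker.workflow.pipeline_context import PipelineSession
from sagemaker.workflow.steps import TransformStep

pipeline_session = PipelineSession()

transformer = Transformer(
    model_name="my-model",  # placeholder
    instance_count=1,
    instance_type="ml.m5.xlarge",
    output_path="s3://my-bucket/transform-output",  # placeholder
    sagemaker_session=pipeline_session,
)

step_args = transformer.transform(
    data="s3://my-bucket/batch-input",  # placeholder
    experiment_config={"TrialComponentDisplayName": "my-transform-display-name"},
)

step = TransformStep(name="MyTransformStep", step_args=step_args)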

src/sagemaker/workflow/execution_variables.py

Lines changed: 4 additions & 0 deletions
@@ -58,6 +58,8 @@ class ExecutionVariables:
     - ExecutionVariables.PIPELINE_ARN
     - ExecutionVariables.PIPELINE_EXECUTION_ID
     - ExecutionVariables.PIPELINE_EXECUTION_ARN
+    - ExecutionVariables.TRAINING_JOB_NAME
+    - ExecutionVariables.PROCESSING_JOB_NAME
     """

     START_DATETIME = ExecutionVariable("StartDateTime")
@@ -66,3 +68,5 @@ class ExecutionVariables:
     PIPELINE_ARN = ExecutionVariable("PipelineArn")
     PIPELINE_EXECUTION_ID = ExecutionVariable("PipelineExecutionId")
     PIPELINE_EXECUTION_ARN = ExecutionVariable("PipelineExecutionArn")
+    TRAINING_JOB_NAME = ExecutionVariable("TrainingJobName")
+    PROCESSING_JOB_NAME = ExecutionVariable("ProcessingJobName")
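
For reference, each ExecutionVariable is rendered as a Get expression in the pipeline definition; assuming the usual `expr` property on ExecutionVariable, the two new members resolve as follows:

from sagemaker.workflow.execution_variables import ExecutionVariables

assert ExecutionVariables.TRAINING_JOB_NAME.expr == {"Get": "Execution.TrainingJobName"}
assert ExecutionVariables.PROCESSING_JOB_NAME.expr == {"Get": "Execution.ProcessingJobName"}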

src/sagemaker/workflow/steps.py

Lines changed: 32 additions & 3 deletions
@@ -429,7 +429,16 @@ def arguments(self) -> RequestType:
         request_dict["HyperParameters"].pop("sagemaker_job_name", None)

         request_dict.pop("TrainingJobName", None)
-        request_dict.pop("ExperimentConfig", None)
+
+        # only keep the trial component name
+        if request_dict.get("ExperimentConfig", {}).get("TrialComponentDisplayName"):
+            request_dict["ExperimentConfig"] = {
+                "TrialComponentDisplayName": request_dict["ExperimentConfig"][
+                    "TrialComponentDisplayName"
+                ]
+            }
+        else:
+            request_dict.pop("ExperimentConfig", None)

         return request_dict

@@ -660,7 +669,17 @@ def arguments(self) -> RequestType:
         )

         request_dict.pop("TransformJobName", None)
-        request_dict.pop("ExperimentConfig", None)
+
+        # only keep the trial component name
+        if request_dict.get("ExperimentConfig", {}).get("TrialComponentDisplayName"):
+            request_dict["ExperimentConfig"] = {
+                "TrialComponentDisplayName": request_dict["ExperimentConfig"][
+                    "TrialComponentDisplayName"
+                ]
+            }
+        else:
+            request_dict.pop("ExperimentConfig", None)
+
         return request_dict

     @property
@@ -808,7 +827,17 @@ def arguments(self) -> RequestType:
         request_dict = self.processor.sagemaker_session._get_process_request(**process_args)

         request_dict.pop("ProcessingJobName", None)
-        request_dict.pop("ExperimentConfig", None)
+
+        # only keep the trial component name
+        if request_dict.get("ExperimentConfig", {}).get("TrialComponentDisplayName"):
+            request_dict["ExperimentConfig"] = {
+                "TrialComponentDisplayName": request_dict["ExperimentConfig"][
+                    "TrialComponentDisplayName"
+                ]
+            }
+        else:
+            request_dict.pop("ExperimentConfig", None)
+
         return request_dict

     @property
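
The three hunks above apply the same pruning rule to training, transform, and processing requests. Pulled out into a hypothetical standalone helper (not part of the commit), the behavior is:

def keep_only_trial_component_display_name(request_dict):
    """Hypothetical helper mirroring the logic added above.

    The pipeline execution manages ExperimentName/TrialName itself, so only a
    user-supplied TrialComponentDisplayName is propagated to the step arguments.
    """
    experiment_config = request_dict.get("ExperimentConfig") or {}
    display_name = experiment_config.get("TrialComponentDisplayName")
    if display_name:
        request_dict["ExperimentConfig"] = {"TrialComponentDisplayName": display_name}
    else:
        request_dict.pop("ExperimentConfig", None)
    return request_dict


# Example: ExperimentName is dropped, the display name is kept.
print(keep_only_trial_component_display_name(
    {"ExperimentConfig": {"ExperimentName": "e", "TrialComponentDisplayName": "d"}}
))
# {'ExperimentConfig': {'TrialComponentDisplayName': 'd'}}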

tests/data/_repack_model.py

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Repack model script for training jobs to inject entry points"""
from __future__ import absolute_import

import argparse
import os
import shutil
import tarfile
import tempfile

# Repack Model
# The following script is run via a training job which takes an existing model and a custom
# entry point script as arguments. The script creates a new model archive with the custom
# entry point in the "code" directory along with the existing model. Subsequently, when the model
# is unpacked for inference, the custom entry point will be used.
# Reference: https://docs.aws.amazon.com/sagemaker/latest/dg/amazon-sagemaker-toolkits.html

# distutils.dir_util.copy_tree works way better than the half-baked
# shutil.copytree which bombs on previously existing target dirs...
# alas ... https://bugs.python.org/issue10948
# we'll go ahead and use the copy_tree function anyways because this
# repacking is some short-lived hackery, right??
from distutils.dir_util import copy_tree


def repack(inference_script, model_archive, dependencies=None, source_dir=None):  # pragma: no cover
    """Repack custom dependencies and code into an existing model TAR archive

    Args:
        inference_script (str): The path to the custom entry point.
        model_archive (str): The name or path (e.g. s3 uri) of the model TAR archive.
        dependencies (str): A space-delimited string of paths to custom dependencies.
        source_dir (str): The path to a custom source directory.
    """

    # the data directory contains a model archive generated by a previous training job
    data_directory = "/opt/ml/input/data/training"
    model_path = os.path.join(data_directory, model_archive.split("/")[-1])

    # create a temporary directory
    with tempfile.TemporaryDirectory() as tmp:
        local_path = os.path.join(tmp, "local.tar.gz")
        # copy the previous training job's model archive to the temporary directory
        shutil.copy2(model_path, local_path)
        src_dir = os.path.join(tmp, "src")
        # create the "code" directory which will contain the inference script
        code_dir = os.path.join(src_dir, "code")
        os.makedirs(code_dir)
        # extract the contents of the previous training job's model archive to the "src"
        # directory of this training job
        with tarfile.open(name=local_path, mode="r:gz") as tf:
            tf.extractall(path=src_dir)

        if source_dir:
            # copy /opt/ml/code to code/
            if os.path.exists(code_dir):
                shutil.rmtree(code_dir)
            shutil.copytree("/opt/ml/code", code_dir)
        else:
            # copy the custom inference script to code/
            entry_point = os.path.join("/opt/ml/code", inference_script)
            shutil.copy2(entry_point, os.path.join(code_dir, inference_script))

        # copy any dependencies to code/lib/
        if dependencies:
            for dependency in dependencies.split(" "):
                actual_dependency_path = os.path.join("/opt/ml/code", dependency)
                lib_dir = os.path.join(code_dir, "lib")
                if not os.path.exists(lib_dir):
                    os.mkdir(lib_dir)
                if os.path.isfile(actual_dependency_path):
                    shutil.copy2(actual_dependency_path, lib_dir)
                else:
                    if os.path.exists(lib_dir):
                        shutil.rmtree(lib_dir)
                    # a directory is in the dependencies. we have to copy
                    # all of /opt/ml/code into the lib dir because the original directory
                    # was flattened by the SDK training job upload..
                    shutil.copytree("/opt/ml/code", lib_dir)
                    break

        # copy the "src" dir, which includes the previous training job's model and the
        # custom inference script, to the output of this training job
        copy_tree(src_dir, "/opt/ml/model")


if __name__ == "__main__":  # pragma: no cover
    parser = argparse.ArgumentParser()
    parser.add_argument("--inference_script", type=str, default="inference.py")
    parser.add_argument("--dependencies", type=str, default=None)
    parser.add_argument("--source_dir", type=str, default=None)
    parser.add_argument("--model_archive", type=str, default="model.tar.gz")
    args, extra = parser.parse_known_args()
    repack(
        inference_script=args.inference_script,
        dependencies=args.dependencies,
        source_dir=args.source_dir,
        model_archive=args.model_archive,
    )
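
For context, the script is meant to run inside a SageMaker training container, where the /opt/ml paths used above are fixed by the environment; a hypothetical direct call would look like:

# Assumes the container layout above: the original model archive is staged under
# /opt/ml/input/data/training and the entry point lives under /opt/ml/code.
repack(
    inference_script="inference.py",
    model_archive="s3://my-bucket/model.tar.gz",  # only the basename is used locally
    dependencies=None,
    source_dir=None,
)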

tests/unit/sagemaker/workflow/test_processing_step.py

Lines changed: 40 additions & 3 deletions
@@ -18,6 +18,8 @@
 import pytest
 import warnings

+from copy import deepcopy
+
 from sagemaker.estimator import Estimator
 from sagemaker.parameter import IntegerParameter
 from sagemaker.transformer import Transformer
@@ -268,7 +270,34 @@ def network_config():
     )


-def test_processing_step_with_processor(pipeline_session, processing_input):
+@pytest.mark.parametrize(
+    "experiment_config, expected_experiment_config",
+    [
+        (
+            {
+                "ExperimentName": "experiment-name",
+                "TrialName": "trial-name",
+                "TrialComponentDisplayName": "display-name",
+            },
+            {"TrialComponentDisplayName": "display-name"},
+        ),
+        (
+            {"TrialComponentDisplayName": "display-name"},
+            {"TrialComponentDisplayName": "display-name"},
+        ),
+        (
+            {
+                "ExperimentName": "experiment-name",
+                "TrialName": "trial-name",
+            },
+            None,
+        ),
+        (None, None),
+    ],
+)
+def test_processing_step_with_processor(
+    pipeline_session, processing_input, experiment_config, expected_experiment_config
+):
     custom_step1 = CustomStep("TestStep")
     custom_step2 = CustomStep("SecondTestStep")
     processor = Processor(
@@ -280,7 +309,7 @@ def test_processing_step_with_processor(pipeline_session, processing_input):
     )

     with warnings.catch_warnings(record=True) as w:
-        step_args = processor.run(inputs=processing_input)
+        step_args = processor.run(inputs=processing_input, experiment_config=experiment_config)
     assert len(w) == 1
     assert issubclass(w[-1].category, UserWarning)
     assert "Running within a PipelineSession" in str(w[-1].message)
@@ -307,13 +336,21 @@ def test_processing_step_with_processor(pipeline_session, processing_input):
         steps=[step, custom_step1, custom_step2],
         sagemaker_session=pipeline_session,
     )
+
+    expected_step_arguments = deepcopy(step_args.args)
+    if expected_experiment_config is None:
+        expected_step_arguments.pop("ExperimentConfig", None)
+    else:
+        expected_step_arguments["ExperimentConfig"] = expected_experiment_config
+    del expected_step_arguments["ProcessingJobName"]
+
     assert json.loads(pipeline.definition())["Steps"][0] == {
         "Name": "MyProcessingStep",
         "Description": "ProcessingStep description",
         "DisplayName": "MyProcessingStep",
         "Type": "Processing",
         "DependsOn": ["TestStep", "SecondTestStep"],
-        "Arguments": step_args.args,
+        "Arguments": expected_step_arguments,
         "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
         "PropertyFiles": [
0 commit comments
