Skip to content

Commit 02fdf1b

Browse files
navaj0 and Zhankuil authored
Allow users to customize trial component display names for pipeline launched jobs (#3230)
Co-authored-by: Zhankui Lu <[email protected]>
1 parent 284ddbe commit 02fdf1b

10 files changed

+238
-27
lines changed

doc/amazon_sagemaker_model_building_pipeline.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,8 @@ There are a number of properties for a pipeline execution that can only be resol
741741
- :class:`sagemaker.workflow.execution_variables.ExecutionVariables.PIPELINE_EXECUTION_ARN`: The execution ARN for an execution.
742742
- :class:`sagemaker.workflow.execution_variables.ExecutionVariables.PIPELINE_NAME`: The name of the pipeline.
743743
- :class:`sagemaker.workflow.execution_variables.ExecutionVariables.PIPELINE_ARN`: The ARN of the pipeline.
744+
- :class:`sagemaker.workflow.execution_variables.ExecutionVariables.TRAINING_JOB_NAME`: The name of the training job launched by the training step.
745+
- :class:`sagemaker.workflow.execution_variables.ExecutionVariables.PROCESSING_JOB_NAME`: The name of the processing job launched by the processing step.
744746
745747
You can use these execution variables as you see fit. The following example uses the :code:`START_DATETIME` execution variable to construct a processing output path:
746748

doc/workflows/pipelines/sagemaker.workflow.pipelines.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ Execution Variables
5252
.. autoclass:: sagemaker.workflow.execution_variables.ExecutionVariable
5353

5454
.. autoclass:: sagemaker.workflow.execution_variables.ExecutionVariables
55-
:members: START_DATETIME, CURRENT_DATETIME, PIPELINE_EXECUTION_ID, PIPELINE_EXECUTION_ARN, PIPELINE_NAME, PIPELINE_ARN
55+
:members: START_DATETIME, CURRENT_DATETIME, PIPELINE_EXECUTION_ID, PIPELINE_EXECUTION_ARN, PIPELINE_NAME, PIPELINE_ARN, TRAINING_JOB_NAME, PROCESSING_JOB_NAME
5656

5757
Functions
5858
---------

src/sagemaker/estimator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,6 +1025,12 @@ def fit(
10251025
* If both `ExperimentName` and `TrialName` are not supplied the trial component
10261026
will be unassociated.
10271027
* `TrialComponentDisplayName` is used for display in Studio.
1028+
* Both `ExperimentName` and `TrialName` will be ignored if the Estimator instance
1029+
is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`.
1030+
However, the value of `TrialComponentDisplayName` is honored for display in Studio.
1031+
Returns:
1032+
None or pipeline step arguments in case the Estimator instance is built with
1033+
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
10281034
"""
10291035
self._prepare_for_training(job_name=job_name)
10301036

src/sagemaker/processing.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,14 @@ def run(
173173
* If both `ExperimentName` and `TrialName` are not supplied the trial component
174174
will be unassociated.
175175
* `TrialComponentDisplayName` is used for display in Studio.
176+
* Both `ExperimentName` and `TrialName` will be ignored if the Processor instance
177+
is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`.
178+
However, the value of `TrialComponentDisplayName` is honored for display in Studio.
176179
kms_key (str): The ARN of the KMS key that is used to encrypt the
177180
user code file (default: None).
178-
181+
Returns:
182+
None or pipeline step arguments in case the Processor instance is built with
183+
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
179184
Raises:
180185
ValueError: if ``logs`` is True but ``wait`` is False.
181186
"""
@@ -543,8 +548,14 @@ def run(
543548
* If both `ExperimentName` and `TrialName` are not supplied the trial component
544549
will be unassociated.
545550
* `TrialComponentDisplayName` is used for display in Studio.
551+
* Both `ExperimentName` and `TrialName` will be ignored if the Processor instance
552+
is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`.
553+
However, the value of `TrialComponentDisplayName` is honored for display in Studio.
546554
kms_key (str): The ARN of the KMS key that is used to encrypt the
547555
user code file (default: None).
556+
Returns:
557+
None or pipeline step arguments in case the Processor instance is built with
558+
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
548559
"""
549560
normalized_inputs, normalized_outputs = self._normalize_args(
550561
job_name=job_name,
@@ -1601,8 +1612,14 @@ def run( # type: ignore[override]
16011612
* If both `ExperimentName` and `TrialName` are not supplied the trial component
16021613
will be unassociated.
16031614
* `TrialComponentDisplayName` is used for display in Studio.
1615+
* Both `ExperimentName` and `TrialName` will be ignored if the Processor instance
1616+
is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`.
1617+
However, the value of `TrialComponentDisplayName` is honored for display in Studio.
16041618
kms_key (str): The ARN of the KMS key that is used to encrypt the
16051619
user code file (default: None).
1620+
Returns:
1621+
None or pipeline step arguments in case the Processor instance is built with
1622+
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
16061623
"""
16071624
s3_runproc_sh, inputs, job_name = self._pack_and_upload_code(
16081625
code, source_dir, dependencies, git_config, job_name, inputs

src/sagemaker/transformer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,9 @@ def transform(
186186
* If both `ExperimentName` and `TrialName` are not supplied the trial component
187187
will be unassociated.
188188
* `TrialComponentDisplayName` is used for display in Studio.
189+
* Both `ExperimentName` and `TrialName` will be ignored if the Transformer instance
190+
is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`.
191+
However, the value of `TrialComponentDisplayName` is honored for display in Studio.
189192
model_client_config (dict[str, str]): Model configuration.
190193
Dictionary contains two optional keys,
191194
'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'.
@@ -194,6 +197,9 @@ def transform(
194197
(default: ``True``).
195198
logs (bool): Whether to show the logs produced by the job.
196199
Only meaningful when wait is ``True`` (default: ``True``).
200+
Returns:
201+
None or pipeline step arguments in case the Transformer instance is built with
202+
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
197203
"""
198204
local_mode = self.sagemaker_session.local_mode
199205
if not local_mode and not is_pipeline_variable(data) and not data.startswith("s3://"):

src/sagemaker/workflow/execution_variables.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ class ExecutionVariables:
5858
- ExecutionVariables.PIPELINE_ARN
5959
- ExecutionVariables.PIPELINE_EXECUTION_ID
6060
- ExecutionVariables.PIPELINE_EXECUTION_ARN
61+
- ExecutionVariables.TRAINING_JOB_NAME
62+
- ExecutionVariables.PROCESSING_JOB_NAME
6163
"""
6264

6365
START_DATETIME = ExecutionVariable("StartDateTime")
@@ -66,3 +68,5 @@ class ExecutionVariables:
6668
PIPELINE_ARN = ExecutionVariable("PipelineArn")
6769
PIPELINE_EXECUTION_ID = ExecutionVariable("PipelineExecutionId")
6870
PIPELINE_EXECUTION_ARN = ExecutionVariable("PipelineExecutionArn")
71+
TRAINING_JOB_NAME = ExecutionVariable("TrainingJobName")
72+
PROCESSING_JOB_NAME = ExecutionVariable("ProcessingJobName")

src/sagemaker/workflow/steps.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,18 @@ def _get_step_name_from_str(
223223
return step_map[str_input].steps[-1].name
224224
return str_input
225225

226+
@staticmethod
def _trim_experiment_config(request_dict: Dict):
    """Reduce the ``ExperimentConfig`` of a job request to the display name only.

    For job steps, only ``TrialComponentDisplayName`` is kept; every other
    key of the experiment config (e.g. ``ExperimentName``, ``TrialName``) is
    dropped. When no display name is set, the whole ``ExperimentConfig``
    entry is removed from the request.

    Args:
        request_dict (Dict): The job request to trim. It is modified in place.

    Returns:
        None
    """
    # ``or {}`` also covers a key that is present but explicitly ``None``;
    # a plain ``.get("ExperimentConfig", {})`` would return ``None`` there
    # and the chained lookup below would raise ``AttributeError``.
    experiment_config = request_dict.get("ExperimentConfig") or {}
    display_name = experiment_config.get("TrialComponentDisplayName")

    if display_name:
        request_dict["ExperimentConfig"] = {"TrialComponentDisplayName": display_name}
    else:
        request_dict.pop("ExperimentConfig", None)
237+
226238

227239
@attr.s
228240
class CacheConfig:
@@ -432,7 +444,7 @@ def arguments(self) -> RequestType:
432444
request_dict["HyperParameters"].pop("sagemaker_job_name", None)
433445

434446
request_dict.pop("TrainingJobName", None)
435-
request_dict.pop("ExperimentConfig", None)
447+
Step._trim_experiment_config(request_dict)
436448

437449
return request_dict
438450

@@ -663,7 +675,8 @@ def arguments(self) -> RequestType:
663675
)
664676

665677
request_dict.pop("TransformJobName", None)
666-
request_dict.pop("ExperimentConfig", None)
678+
Step._trim_experiment_config(request_dict)
679+
667680
return request_dict
668681

669682
@property
@@ -811,7 +824,8 @@ def arguments(self) -> RequestType:
811824
request_dict = self.processor.sagemaker_session._get_process_request(**process_args)
812825

813826
request_dict.pop("ProcessingJobName", None)
814-
request_dict.pop("ExperimentConfig", None)
827+
Step._trim_experiment_config(request_dict)
828+
815829
return request_dict
816830

817831
@property

tests/unit/sagemaker/workflow/test_processing_step.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import pytest
1919
import warnings
2020

21+
from copy import deepcopy
22+
2123
from sagemaker.estimator import Estimator
2224
from sagemaker.parameter import IntegerParameter
2325
from sagemaker.transformer import Transformer
@@ -244,7 +246,34 @@ def network_config():
244246
)
245247

246248

247-
def test_processing_step_with_processor(pipeline_session, processing_input):
249+
@pytest.mark.parametrize(
250+
"experiment_config, expected_experiment_config",
251+
[
252+
(
253+
{
254+
"ExperimentName": "experiment-name",
255+
"TrialName": "trial-name",
256+
"TrialComponentDisplayName": "display-name",
257+
},
258+
{"TrialComponentDisplayName": "display-name"},
259+
),
260+
(
261+
{"TrialComponentDisplayName": "display-name"},
262+
{"TrialComponentDisplayName": "display-name"},
263+
),
264+
(
265+
{
266+
"ExperimentName": "experiment-name",
267+
"TrialName": "trial-name",
268+
},
269+
None,
270+
),
271+
(None, None),
272+
],
273+
)
274+
def test_processing_step_with_processor(
275+
pipeline_session, processing_input, experiment_config, expected_experiment_config
276+
):
248277
custom_step1 = CustomStep("TestStep")
249278
custom_step2 = CustomStep("SecondTestStep")
250279
processor = Processor(
@@ -256,7 +285,7 @@ def test_processing_step_with_processor(pipeline_session, processing_input):
256285
)
257286

258287
with warnings.catch_warnings(record=True) as w:
259-
step_args = processor.run(inputs=processing_input)
288+
step_args = processor.run(inputs=processing_input, experiment_config=experiment_config)
260289
assert len(w) == 1
261290
assert issubclass(w[-1].category, UserWarning)
262291
assert "Running within a PipelineSession" in str(w[-1].message)
@@ -283,13 +312,21 @@ def test_processing_step_with_processor(pipeline_session, processing_input):
283312
steps=[step, custom_step1, custom_step2],
284313
sagemaker_session=pipeline_session,
285314
)
315+
316+
expected_step_arguments = deepcopy(step_args.args)
317+
if expected_experiment_config is None:
318+
expected_step_arguments.pop("ExperimentConfig", None)
319+
else:
320+
expected_step_arguments["ExperimentConfig"] = expected_experiment_config
321+
del expected_step_arguments["ProcessingJobName"]
322+
286323
assert json.loads(pipeline.definition())["Steps"][0] == {
287324
"Name": "MyProcessingStep",
288325
"Description": "ProcessingStep description",
289326
"DisplayName": "MyProcessingStep",
290327
"Type": "Processing",
291328
"DependsOn": ["TestStep", "SecondTestStep"],
292-
"Arguments": step_args.args,
329+
"Arguments": expected_step_arguments,
293330
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
294331
"PropertyFiles": [
295332
{

tests/unit/sagemaker/workflow/test_training_step.py

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import pytest
2020
import warnings
2121

22+
from copy import deepcopy
23+
2224
from sagemaker import Processor, Model
2325
from sagemaker.parameter import IntegerParameter
2426
from sagemaker.transformer import Transformer
@@ -207,7 +209,34 @@ def hyperparameters():
207209
return {"test-key": "test-val"}
208210

209211

210-
def test_training_step_with_estimator(pipeline_session, training_input, hyperparameters):
212+
@pytest.mark.parametrize(
213+
"experiment_config, expected_experiment_config",
214+
[
215+
(
216+
{
217+
"ExperimentName": "experiment-name",
218+
"TrialName": "trial-name",
219+
"TrialComponentDisplayName": "display-name",
220+
},
221+
{"TrialComponentDisplayName": "display-name"},
222+
),
223+
(
224+
{"TrialComponentDisplayName": "display-name"},
225+
{"TrialComponentDisplayName": "display-name"},
226+
),
227+
(
228+
{
229+
"ExperimentName": "experiment-name",
230+
"TrialName": "trial-name",
231+
},
232+
None,
233+
),
234+
(None, None),
235+
],
236+
)
237+
def test_training_step_with_estimator(
238+
pipeline_session, training_input, hyperparameters, experiment_config, expected_experiment_config
239+
):
211240
custom_step1 = CustomStep("TestStep")
212241
custom_step2 = CustomStep("SecondTestStep")
213242
enable_network_isolation = ParameterBoolean(name="enable_network_isolation")
@@ -226,7 +255,9 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
226255
with warnings.catch_warnings(record=True) as w:
227256
# TODO: remove job_name once we merge
228257
# https://github.com/aws/sagemaker-python-sdk/pull/3158/files
229-
step_args = estimator.fit(inputs=training_input, job_name="TestJob")
258+
step_args = estimator.fit(
259+
inputs=training_input, job_name="TestJob", experiment_config=experiment_config
260+
)
230261
assert len(w) == 1
231262
assert issubclass(w[-1].category, UserWarning)
232263
assert "Running within a PipelineSession" in str(w[-1].message)
@@ -247,17 +278,28 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
247278
parameters=[enable_network_isolation, encrypt_container_traffic],
248279
sagemaker_session=pipeline_session,
249280
)
250-
step_args.args["EnableInterContainerTrafficEncryption"] = {
281+
282+
expected_step_arguments = deepcopy(step_args.args)
283+
284+
expected_step_arguments["EnableInterContainerTrafficEncryption"] = {
251285
"Get": "Parameters.encrypt_container_traffic"
252286
}
253-
step_args.args["EnableNetworkIsolation"] = {"Get": "Parameters.encrypt_container_traffic"}
287+
expected_step_arguments["EnableNetworkIsolation"] = {
288+
"Get": "Parameters.enable_network_isolation"
289+
}
290+
if expected_experiment_config is None:
291+
expected_step_arguments.pop("ExperimentConfig", None)
292+
else:
293+
expected_step_arguments["ExperimentConfig"] = expected_experiment_config
294+
del expected_step_arguments["TrainingJobName"]
295+
254296
assert json.loads(pipeline.definition())["Steps"][0] == {
255297
"Name": "MyTrainingStep",
256298
"Description": "TrainingStep description",
257299
"DisplayName": "MyTrainingStep",
258300
"Type": "Training",
259301
"DependsOn": ["TestStep", "SecondTestStep"],
260-
"Arguments": step_args.args,
302+
"Arguments": expected_step_arguments,
261303
}
262304
assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"}
263305
adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list

0 commit comments

Comments
 (0)