Skip to content

Commit eaf35a7

Browse files
committed
Allow users to customize trial component display names for pipeline-launched jobs
1 parent a5464a2 commit eaf35a7

10 files changed

+238
-27
lines changed

doc/amazon_sagemaker_model_building_pipeline.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,8 @@ There are a number of properties for a pipeline execution that can only be resol
741741
- :class:`sagemaker.workflow.execution_variables.ExecutionVariables.PIPELINE_EXECUTION_ARN`: The execution ARN for an execution.
742742
- :class:`sagemaker.workflow.execution_variables.ExecutionVariables.PIPELINE_NAME`: The name of the pipeline.
743743
- :class:`sagemaker.workflow.execution_variables.ExecutionVariables.PIPELINE_ARN`: The ARN of the pipeline.
744+
- :class:`sagemaker.workflow.execution_variables.ExecutionVariables.TRAINING_JOB_NAME`: The name of the training job launched by the training step.
745+
- :class:`sagemaker.workflow.execution_variables.ExecutionVariables.PROCESSING_JOB_NAME`: The name of the processing job launched by the processing step.
744746
745747
You can use these execution variables as you see fit. The following example uses the :code:`START_DATETIME` execution variable to construct a processing output path:
746748

doc/workflows/pipelines/sagemaker.workflow.pipelines.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ Execution Variables
5252
.. autoclass:: sagemaker.workflow.execution_variables.ExecutionVariable
5353

5454
.. autoclass:: sagemaker.workflow.execution_variables.ExecutionVariables
55-
:members: START_DATETIME, CURRENT_DATETIME, PIPELINE_EXECUTION_ID, PIPELINE_EXECUTION_ARN, PIPELINE_NAME, PIPELINE_ARN
55+
:members: START_DATETIME, CURRENT_DATETIME, PIPELINE_EXECUTION_ID, PIPELINE_EXECUTION_ARN, PIPELINE_NAME, PIPELINE_ARN, TRAINING_JOB_NAME, PROCESSING_JOB_NAME
5656

5757
Functions
5858
---------

src/sagemaker/estimator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,6 +1000,12 @@ def fit(
10001000
* If both `ExperimentName` and `TrialName` are not supplied the trial component
10011001
will be unassociated.
10021002
* `TrialComponentDisplayName` is used for display in Studio.
1003+
* Both `ExperimentName` and `TrialName` will be ignored if the Estimator instance
1004+
is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`.
1005+
However, the value of `TrialComponentDisplayName` is honored for display in Studio.
1006+
Returns:
1007+
None or pipeline step arguments in case the Estimator instance is built with
1008+
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
10031009
"""
10041010
self._prepare_for_training(job_name=job_name)
10051011

src/sagemaker/processing.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,14 @@ def run(
173173
* If both `ExperimentName` and `TrialName` are not supplied the trial component
174174
will be unassociated.
175175
* `TrialComponentDisplayName` is used for display in Studio.
176+
* Both `ExperimentName` and `TrialName` will be ignored if the Processor instance
177+
is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`.
178+
However, the value of `TrialComponentDisplayName` is honored for display in Studio.
176179
kms_key (str): The ARN of the KMS key that is used to encrypt the
177180
user code file (default: None).
178-
181+
Returns:
182+
None or pipeline step arguments in case the Processor instance is built with
183+
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
179184
Raises:
180185
ValueError: if ``logs`` is True but ``wait`` is False.
181186
"""
@@ -543,8 +548,14 @@ def run(
543548
* If both `ExperimentName` and `TrialName` are not supplied the trial component
544549
will be unassociated.
545550
* `TrialComponentDisplayName` is used for display in Studio.
551+
* Both `ExperimentName` and `TrialName` will be ignored if the Processor instance
552+
is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`.
553+
However, the value of `TrialComponentDisplayName` is honored for display in Studio.
546554
kms_key (str): The ARN of the KMS key that is used to encrypt the
547555
user code file (default: None).
556+
Returns:
557+
None or pipeline step arguments in case the Processor instance is built with
558+
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
548559
"""
549560
normalized_inputs, normalized_outputs = self._normalize_args(
550561
job_name=job_name,
@@ -1601,8 +1612,14 @@ def run( # type: ignore[override]
16011612
* If both `ExperimentName` and `TrialName` are not supplied the trial component
16021613
will be unassociated.
16031614
* `TrialComponentDisplayName` is used for display in Studio.
1615+
* Both `ExperimentName` and `TrialName` will be ignored if the Processor instance
1616+
is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`.
1617+
However, the value of `TrialComponentDisplayName` is honored for display in Studio.
16041618
kms_key (str): The ARN of the KMS key that is used to encrypt the
16051619
user code file (default: None).
1620+
Returns:
1621+
None or pipeline step arguments in case the Processor instance is built with
1622+
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
16061623
"""
16071624
s3_runproc_sh, inputs, job_name = self._pack_and_upload_code(
16081625
code, source_dir, dependencies, git_config, job_name, inputs

src/sagemaker/transformer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,9 @@ def transform(
186186
* If both `ExperimentName` and `TrialName` are not supplied the trial component
187187
will be unassociated.
188188
* `TrialComponentDisplayName` is used for display in Studio.
189+
* Both `ExperimentName` and `TrialName` will be ignored if the Transformer instance
190+
is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`.
191+
However, the value of `TrialComponentDisplayName` is honored for display in Studio.
189192
model_client_config (dict[str, str]): Model configuration.
190193
Dictionary contains two optional keys,
191194
'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'.
@@ -194,6 +197,11 @@ def transform(
194197
(default: ``True``).
195198
logs (bool): Whether to show the logs produced by the job.
196199
Only meaningful when wait is ``True`` (default: ``True``).
200+
kms_key (str): The ARN of the KMS key that is used to encrypt the
201+
user code file (default: None).
202+
Returns:
203+
None or pipeline step arguments in case the Transformer instance is built with
204+
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
197205
"""
198206
local_mode = self.sagemaker_session.local_mode
199207
if not local_mode and not is_pipeline_variable(data) and not data.startswith("s3://"):

src/sagemaker/workflow/execution_variables.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ class ExecutionVariables:
5858
- ExecutionVariables.PIPELINE_ARN
5959
- ExecutionVariables.PIPELINE_EXECUTION_ID
6060
- ExecutionVariables.PIPELINE_EXECUTION_ARN
61+
- ExecutionVariables.TRAINING_JOB_NAME
62+
- ExecutionVariables.PROCESSING_JOB_NAME
6163
"""
6264

6365
START_DATETIME = ExecutionVariable("StartDateTime")
@@ -66,3 +68,5 @@ class ExecutionVariables:
6668
PIPELINE_ARN = ExecutionVariable("PipelineArn")
6769
PIPELINE_EXECUTION_ID = ExecutionVariable("PipelineExecutionId")
6870
PIPELINE_EXECUTION_ARN = ExecutionVariable("PipelineExecutionArn")
71+
TRAINING_JOB_NAME = ExecutionVariable("TrainingJobName")
72+
PROCESSING_JOB_NAME = ExecutionVariable("ProcessingJobName")

src/sagemaker/workflow/steps.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,18 @@ def _get_step_name_from_str(
223223
return step_map[str_input].steps[-1].name
224224
return str_input
225225

226+
@staticmethod
def _trim_experiment_config(request_dict: Dict):
    """For job steps, trim the experiment config to keep the trial component display name.

    The request is modified in place. Every ``ExperimentConfig`` key other than
    ``TrialComponentDisplayName`` (e.g. ``ExperimentName`` and ``TrialName``) is
    dropped; if no truthy display name is present, the whole ``ExperimentConfig``
    entry is removed from the request.

    Args:
        request_dict (Dict): The job request arguments to trim in place.
    """
    # ``or {}`` also covers an explicit ``"ExperimentConfig": None`` entry, which
    # would otherwise raise AttributeError on the chained ``.get`` below.
    experiment_config = request_dict.get("ExperimentConfig") or {}
    display_name = experiment_config.get("TrialComponentDisplayName")

    if display_name:
        # Rebuild the config so that only the display name survives.
        request_dict["ExperimentConfig"] = {"TrialComponentDisplayName": display_name}
    else:
        request_dict.pop("ExperimentConfig", None)
237+
226238

227239
@attr.s
228240
class CacheConfig:
@@ -429,7 +441,7 @@ def arguments(self) -> RequestType:
429441
request_dict["HyperParameters"].pop("sagemaker_job_name", None)
430442

431443
request_dict.pop("TrainingJobName", None)
432-
request_dict.pop("ExperimentConfig", None)
444+
Step._trim_experiment_config(request_dict)
433445

434446
return request_dict
435447

@@ -660,7 +672,8 @@ def arguments(self) -> RequestType:
660672
)
661673

662674
request_dict.pop("TransformJobName", None)
663-
request_dict.pop("ExperimentConfig", None)
675+
Step._trim_experiment_config(request_dict)
676+
664677
return request_dict
665678

666679
@property
@@ -808,7 +821,8 @@ def arguments(self) -> RequestType:
808821
request_dict = self.processor.sagemaker_session._get_process_request(**process_args)
809822

810823
request_dict.pop("ProcessingJobName", None)
811-
request_dict.pop("ExperimentConfig", None)
824+
Step._trim_experiment_config(request_dict)
825+
812826
return request_dict
813827

814828
@property

tests/unit/sagemaker/workflow/test_processing_step.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import pytest
1919
import warnings
2020

21+
from copy import deepcopy
22+
2123
from sagemaker.estimator import Estimator
2224
from sagemaker.parameter import IntegerParameter
2325
from sagemaker.transformer import Transformer
@@ -268,7 +270,34 @@ def network_config():
268270
)
269271

270272

271-
def test_processing_step_with_processor(pipeline_session, processing_input):
273+
@pytest.mark.parametrize(
274+
"experiment_config, expected_experiment_config",
275+
[
276+
(
277+
{
278+
"ExperimentName": "experiment-name",
279+
"TrialName": "trial-name",
280+
"TrialComponentDisplayName": "display-name",
281+
},
282+
{"TrialComponentDisplayName": "display-name"},
283+
),
284+
(
285+
{"TrialComponentDisplayName": "display-name"},
286+
{"TrialComponentDisplayName": "display-name"},
287+
),
288+
(
289+
{
290+
"ExperimentName": "experiment-name",
291+
"TrialName": "trial-name",
292+
},
293+
None,
294+
),
295+
(None, None),
296+
],
297+
)
298+
def test_processing_step_with_processor(
299+
pipeline_session, processing_input, experiment_config, expected_experiment_config
300+
):
272301
custom_step1 = CustomStep("TestStep")
273302
custom_step2 = CustomStep("SecondTestStep")
274303
processor = Processor(
@@ -280,7 +309,7 @@ def test_processing_step_with_processor(pipeline_session, processing_input):
280309
)
281310

282311
with warnings.catch_warnings(record=True) as w:
283-
step_args = processor.run(inputs=processing_input)
312+
step_args = processor.run(inputs=processing_input, experiment_config=experiment_config)
284313
assert len(w) == 1
285314
assert issubclass(w[-1].category, UserWarning)
286315
assert "Running within a PipelineSession" in str(w[-1].message)
@@ -307,13 +336,21 @@ def test_processing_step_with_processor(pipeline_session, processing_input):
307336
steps=[step, custom_step1, custom_step2],
308337
sagemaker_session=pipeline_session,
309338
)
339+
340+
expected_step_arguments = deepcopy(step_args.args)
341+
if expected_experiment_config is None:
342+
expected_step_arguments.pop("ExperimentConfig", None)
343+
else:
344+
expected_step_arguments["ExperimentConfig"] = expected_experiment_config
345+
del expected_step_arguments["ProcessingJobName"]
346+
310347
assert json.loads(pipeline.definition())["Steps"][0] == {
311348
"Name": "MyProcessingStep",
312349
"Description": "ProcessingStep description",
313350
"DisplayName": "MyProcessingStep",
314351
"Type": "Processing",
315352
"DependsOn": ["TestStep", "SecondTestStep"],
316-
"Arguments": step_args.args,
353+
"Arguments": expected_step_arguments,
317354
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
318355
"PropertyFiles": [
319356
{

tests/unit/sagemaker/workflow/test_training_step.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import pytest
2020
import warnings
2121

22+
from copy import deepcopy
23+
2224
from sagemaker import Processor, Model
2325
from sagemaker.parameter import IntegerParameter
2426
from sagemaker.transformer import Transformer
@@ -200,7 +202,34 @@ def hyperparameters():
200202
return {"test-key": "test-val"}
201203

202204

203-
def test_training_step_with_estimator(pipeline_session, training_input, hyperparameters):
205+
@pytest.mark.parametrize(
206+
"experiment_config, expected_experiment_config",
207+
[
208+
(
209+
{
210+
"ExperimentName": "experiment-name",
211+
"TrialName": "trial-name",
212+
"TrialComponentDisplayName": "display-name",
213+
},
214+
{"TrialComponentDisplayName": "display-name"},
215+
),
216+
(
217+
{"TrialComponentDisplayName": "display-name"},
218+
{"TrialComponentDisplayName": "display-name"},
219+
),
220+
(
221+
{
222+
"ExperimentName": "experiment-name",
223+
"TrialName": "trial-name",
224+
},
225+
None,
226+
),
227+
(None, None),
228+
],
229+
)
230+
def test_training_step_with_estimator(
231+
pipeline_session, training_input, hyperparameters, experiment_config, expected_experiment_config
232+
):
204233
custom_step1 = CustomStep("TestStep")
205234
custom_step2 = CustomStep("SecondTestStep")
206235
enable_network_isolation = ParameterBoolean(name="enable_network_isolation")
@@ -217,7 +246,7 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
217246
)
218247

219248
with warnings.catch_warnings(record=True) as w:
220-
step_args = estimator.fit(inputs=training_input)
249+
step_args = estimator.fit(inputs=training_input, experiment_config=experiment_config)
221250
assert len(w) == 1
222251
assert issubclass(w[-1].category, UserWarning)
223252
assert "Running within a PipelineSession" in str(w[-1].message)
@@ -238,17 +267,28 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
238267
parameters=[enable_network_isolation, encrypt_container_traffic],
239268
sagemaker_session=pipeline_session,
240269
)
241-
step_args.args["EnableInterContainerTrafficEncryption"] = {
270+
271+
expected_step_arguments = deepcopy(step_args.args)
272+
273+
expected_step_arguments["EnableInterContainerTrafficEncryption"] = {
242274
"Get": "Parameters.encrypt_container_traffic"
243275
}
244-
step_args.args["EnableNetworkIsolation"] = {"Get": "Parameters.encrypt_container_traffic"}
276+
expected_step_arguments["EnableNetworkIsolation"] = {
277+
"Get": "Parameters.enable_network_isolation"
278+
}
279+
if expected_experiment_config is None:
280+
expected_step_arguments.pop("ExperimentConfig", None)
281+
else:
282+
expected_step_arguments["ExperimentConfig"] = expected_experiment_config
283+
del expected_step_arguments["TrainingJobName"]
284+
245285
assert json.loads(pipeline.definition())["Steps"][0] == {
246286
"Name": "MyTrainingStep",
247287
"Description": "TrainingStep description",
248288
"DisplayName": "MyTrainingStep",
249289
"Type": "Training",
250290
"DependsOn": ["TestStep", "SecondTestStep"],
251-
"Arguments": step_args.args,
291+
"Arguments": expected_step_arguments,
252292
}
253293
assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"}
254294
adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list

0 commit comments

Comments
 (0)