Skip to content

Commit cb16a1b

Browse files
navaj0ajaykarpur
andauthored
feature: add pipeline experiment config (#2331)
Co-authored-by: Ajay Karpur <[email protected]>
1 parent 2a61c41 commit cb16a1b

File tree

7 files changed

+449
-50
lines changed

7 files changed

+449
-50
lines changed

src/sagemaker/workflow/execution_variables.py

Lines changed: 4 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,56 +13,27 @@
1313
"""Pipeline parameters and conditions for workflow."""
1414
from __future__ import absolute_import
1515

16-
from typing import Dict
17-
1816
from sagemaker.workflow.entities import (
19-
Entity,
17+
Expression,
2018
RequestType,
2119
)
2220

2321

24-
class ExecutionVariable(Entity, str):
22+
class ExecutionVariable(Expression):
2523
"""Pipeline execution variables for workflow."""
2624

27-
def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument
28-
"""Subclass str"""
29-
value = ""
30-
if len(args) == 1:
31-
value = args[0] or value
32-
elif kwargs:
33-
value = kwargs.get("name", value)
34-
return str.__new__(cls, ExecutionVariable._expr(value))
35-
3625
def __init__(self, name: str):
3726
"""Create a pipeline execution variable.
3827
3928
Args:
4029
name (str): The name of the execution variable.
4130
"""
42-
super(ExecutionVariable, self).__init__()
4331
self.name = name
4432

45-
def __hash__(self):
46-
"""Hash function for execution variable types"""
47-
return hash(tuple(self.to_request()))
48-
49-
def to_request(self) -> RequestType:
50-
"""Get the request structure for workflow service calls."""
51-
return self.expr
52-
5333
@property
54-
def expr(self) -> Dict[str, str]:
34+
def expr(self) -> RequestType:
5535
"""The 'Get' expression dict for an `ExecutionVariable`."""
56-
return ExecutionVariable._expr(self.name)
57-
58-
@classmethod
59-
def _expr(cls, name):
60-
"""An internal classmethod for the 'Get' expression dict for an `ExecutionVariable`.
61-
62-
Args:
63-
name (str): The name of the execution variable.
64-
"""
65-
return {"Get": f"Execution.{name}"}
36+
return {"Get": f"Execution.{self.name}"}
6637

6738

6839
class ExecutionVariables:

src/sagemaker/workflow/pipeline.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import json
1717

1818
from copy import deepcopy
19-
from typing import Any, Dict, List, Sequence, Union
19+
from typing import Any, Dict, List, Sequence, Union, Optional
2020

2121
import attr
2222
import botocore
@@ -30,7 +30,9 @@
3030
Expression,
3131
RequestType,
3232
)
33+
from sagemaker.workflow.execution_variables import ExecutionVariables
3334
from sagemaker.workflow.parameters import Parameter
35+
from sagemaker.workflow.pipeline_experiment_config import PipelineExperimentConfig
3436
from sagemaker.workflow.properties import Properties
3537
from sagemaker.workflow.steps import Step
3638
from sagemaker.workflow.step_collections import StepCollection
@@ -44,6 +46,12 @@ class Pipeline(Entity):
4446
Attributes:
4547
name (str): The name of the pipeline.
4648
parameters (Sequence[Parameters]): The list of the parameters.
49+
pipeline_experiment_config (Optional[PipelineExperimentConfig]): If set,
50+
the workflow will attempt to create an experiment and trial before
51+
executing the steps. Creation will be skipped if an experiment or a trial with
52+
the same name already exists. By default, pipeline name is used as
53+
experiment name and execution id is used as the trial name.
54+
If set to None, no experiment or trial will be created automatically.
4755
steps (Sequence[Steps]): The list of the non-conditional steps associated with the pipeline.
4856
Any steps that are within the
4957
`if_steps` or `else_steps` of a `ConditionStep` cannot be listed in the steps of a
@@ -57,6 +65,11 @@ class Pipeline(Entity):
5765

5866
name: str = attr.ib(factory=str)
5967
parameters: Sequence[Parameter] = attr.ib(factory=list)
68+
pipeline_experiment_config: Optional[PipelineExperimentConfig] = attr.ib(
69+
default=PipelineExperimentConfig(
70+
ExecutionVariables.PIPELINE_NAME, ExecutionVariables.PIPELINE_EXECUTION_ID
71+
)
72+
)
6073
steps: Sequence[Union[Step, StepCollection]] = attr.ib(factory=list)
6174
sagemaker_session: Session = attr.ib(factory=Session)
6275

@@ -69,22 +82,23 @@ def to_request(self) -> RequestType:
6982
"Version": self._version,
7083
"Metadata": self._metadata,
7184
"Parameters": list_to_request(self.parameters),
85+
"PipelineExperimentConfig": self.pipeline_experiment_config.to_request()
86+
if self.pipeline_experiment_config is not None
87+
else None,
7288
"Steps": list_to_request(self.steps),
7389
}
7490

7591
def create(
7692
self,
7793
role_arn: str,
7894
description: str = None,
79-
experiment_name: str = None,
8095
tags: List[Dict[str, str]] = None,
8196
) -> Dict[str, Any]:
8297
"""Creates a Pipeline in the Pipelines service.
8398
8499
Args:
85100
role_arn (str): The role arn that is assumed by the pipeline to create step artifacts.
86101
description (str): A description of the pipeline.
87-
experiment_name (str): The name of the experiment.
88102
tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as
89103
tags.
90104
@@ -96,7 +110,6 @@ def create(
96110
kwargs = self._create_args(role_arn, description)
97111
update_args(
98112
kwargs,
99-
ExperimentName=experiment_name,
100113
Tags=tags,
101114
)
102115
return self.sagemaker_session.sagemaker_client.create_pipeline(**kwargs)
@@ -106,7 +119,7 @@ def _create_args(self, role_arn: str, description: str):
106119
107120
Args:
108121
role_arn (str): The role arn that is assumed by pipelines to create step artifacts.
109-
pipeline_description (str): A description of the pipeline.
122+
description (str): A description of the pipeline.
110123
111124
Returns:
112125
A keyword argument dict for calling create_pipeline.
@@ -147,23 +160,21 @@ def upsert(
147160
self,
148161
role_arn: str,
149162
description: str = None,
150-
experiment_name: str = None,
151163
tags: List[Dict[str, str]] = None,
152164
) -> Dict[str, Any]:
153165
"""Creates a pipeline or updates it, if it already exists.
154166
155167
Args:
156168
role_arn (str): The role arn that is assumed by workflow to create step artifacts.
157-
pipeline_description (str): A description of the pipeline.
158-
experiment_name (str): The name of the experiment.
169+
description (str): A description of the pipeline.
159170
tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as
160171
tags.
161172
162173
Returns:
163174
response dict from service
164175
"""
165176
try:
166-
response = self.create(role_arn, description, experiment_name, tags)
177+
response = self.create(role_arn, description, tags)
167178
except ClientError as e:
168179
error = e.response["Error"]
169180
if (
@@ -224,6 +235,9 @@ def start(
224235
def definition(self) -> str:
225236
"""Converts a request structure to string representation for workflow service calls."""
226237
request_dict = self.to_request()
238+
request_dict["PipelineExperimentConfig"] = interpolate(
239+
request_dict["PipelineExperimentConfig"]
240+
)
227241
request_dict["Steps"] = interpolate(request_dict["Steps"])
228242

229243
return json.dumps(request_dict)
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Pipeline experiment config for SageMaker pipeline."""
14+
from __future__ import absolute_import
15+
16+
from typing import Union
17+
18+
from sagemaker.workflow.parameters import Parameter
19+
from sagemaker.workflow.execution_variables import ExecutionVariable
20+
from sagemaker.workflow.entities import (
21+
Entity,
22+
Expression,
23+
RequestType,
24+
)
25+
26+
27+
class PipelineExperimentConfig(Entity):
28+
"""Experiment config for SageMaker pipeline."""
29+
30+
def __init__(
31+
self,
32+
experiment_name: Union[str, Parameter, ExecutionVariable, Expression],
33+
trial_name: Union[str, Parameter, ExecutionVariable, Expression],
34+
):
35+
"""Create a PipelineExperimentConfig
36+
37+
Args:
38+
experiment_name: the name of the experiment that will be created
39+
trial_name: the name of the trial that will be created
40+
"""
41+
self.experiment_name = experiment_name
42+
self.trial_name = trial_name
43+
44+
def to_request(self) -> RequestType:
45+
"""Returns: the request structure."""
46+
47+
return {
48+
"ExperimentName": self.experiment_name,
49+
"TrialName": self.trial_name,
50+
}
51+
52+
53+
class PipelineExperimentConfigProperty(Expression):
54+
"""Reference to pipeline experiment config property."""
55+
56+
def __init__(self, name: str):
57+
"""Create a reference to pipeline experiment property.
58+
59+
Args:
60+
name (str): The name of the pipeline experiment config property.
61+
"""
62+
super(PipelineExperimentConfigProperty, self).__init__()
63+
self.name = name
64+
65+
@property
66+
def expr(self) -> RequestType:
67+
"""The 'Get' expression dict for a pipeline experiment config property."""
68+
69+
return {"Get": f"PipelineExperimentConfig.{self.name}"}
70+
71+
72+
class PipelineExperimentConfigProperties:
73+
"""Enum-like class for all pipeline experiment config property references."""
74+
75+
EXPERIMENT_NAME = PipelineExperimentConfigProperty("ExperimentName")
76+
TRIAL_NAME = PipelineExperimentConfigProperty("TrialName")

0 commit comments

Comments
 (0)