Skip to content

Commit f5ebe8d

Browse files
committed
feature: add pipeline experiment config
1 parent 5333974 commit f5ebe8d

File tree

7 files changed

+448
-50
lines changed

7 files changed

+448
-50
lines changed

src/sagemaker/workflow/execution_variables.py

Lines changed: 4 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,56 +13,27 @@
1313
"""Pipeline parameters and conditions for workflow."""
1414
from __future__ import absolute_import
1515

16-
from typing import Dict
17-
1816
from sagemaker.workflow.entities import (
19-
Entity,
17+
Expression,
2018
RequestType,
2119
)
2220

2321

24-
class ExecutionVariable(Entity, str):
22+
class ExecutionVariable(Expression):
2523
"""Pipeline execution variables for workflow."""
2624

27-
def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument
28-
"""Subclass str"""
29-
value = ""
30-
if len(args) == 1:
31-
value = args[0] or value
32-
elif kwargs:
33-
value = kwargs.get("name", value)
34-
return str.__new__(cls, ExecutionVariable._expr(value))
35-
3625
def __init__(self, name: str):
3726
"""Create a pipeline execution variable.
3827
3928
Args:
4029
name (str): The name of the execution variable.
4130
"""
42-
super(ExecutionVariable, self).__init__()
4331
self.name = name
4432

45-
def __hash__(self):
46-
"""Hash function for execution variable types"""
47-
return hash(tuple(self.to_request()))
48-
49-
def to_request(self) -> RequestType:
50-
"""Get the request structure for workflow service calls."""
51-
return self.expr
52-
5333
@property
54-
def expr(self) -> Dict[str, str]:
34+
def expr(self) -> RequestType:
5535
"""The 'Get' expression dict for an `ExecutionVariable`."""
56-
return ExecutionVariable._expr(self.name)
57-
58-
@classmethod
59-
def _expr(cls, name):
60-
"""An internal classmethod for the 'Get' expression dict for an `ExecutionVariable`.
61-
62-
Args:
63-
name (str): The name of the execution variable.
64-
"""
65-
return {"Get": f"Execution.{name}"}
36+
return {"Get": f"Execution.{self.name}"}
6637

6738

6839
class ExecutionVariables:

src/sagemaker/workflow/pipeline.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import json
1717

1818
from copy import deepcopy
19-
from typing import Any, Dict, List, Sequence, Union
19+
from typing import Any, Dict, List, Sequence, Union, Optional
2020

2121
import attr
2222
import botocore
@@ -30,7 +30,9 @@
3030
Expression,
3131
RequestType,
3232
)
33+
from sagemaker.workflow.execution_variables import ExecutionVariables
3334
from sagemaker.workflow.parameters import Parameter
35+
from sagemaker.workflow.pipeline_experiment_config import PipelineExperimentConfig
3436
from sagemaker.workflow.properties import Properties
3537
from sagemaker.workflow.steps import Step
3638
from sagemaker.workflow.step_collections import StepCollection
@@ -44,6 +46,11 @@ class Pipeline(Entity):
4446
Attributes:
4547
name (str): The name of the pipeline.
4648
parameters (Sequence[Parameters]): The list of the parameters.
49+
pipeline_experiment_config (Optional[PipelineExperimentConfig]): If set,
50+
the workflow will attempt to create an experiment and trial before
51+
executing the steps. Creation will be skipped if an experiment or a trial with
52+
the same name already exists.
53+
If set to None, no experiment or trial will be created automatically.
4754
steps (Sequence[Steps]): The list of the non-conditional steps associated with the pipeline.
4855
Any steps that are within the
4956
`if_steps` or `else_steps` of a `ConditionStep` cannot be listed in the steps of a
@@ -57,6 +64,11 @@ class Pipeline(Entity):
5764

5865
name: str = attr.ib(factory=str)
5966
parameters: Sequence[Parameter] = attr.ib(factory=list)
67+
pipeline_experiment_config: Optional[PipelineExperimentConfig] = attr.ib(
68+
default=PipelineExperimentConfig(
69+
ExecutionVariables.PIPELINE_NAME, ExecutionVariables.PIPELINE_EXECUTION_ID
70+
)
71+
)
6072
steps: Sequence[Union[Step, StepCollection]] = attr.ib(factory=list)
6173
sagemaker_session: Session = attr.ib(factory=Session)
6274

@@ -69,22 +81,23 @@ def to_request(self) -> RequestType:
6981
"Version": self._version,
7082
"Metadata": self._metadata,
7183
"Parameters": list_to_request(self.parameters),
84+
"PipelineExperimentConfig": self.pipeline_experiment_config.to_request()
85+
if self.pipeline_experiment_config is not None
86+
else None,
7287
"Steps": list_to_request(self.steps),
7388
}
7489

7590
def create(
7691
self,
7792
role_arn: str,
7893
description: str = None,
79-
experiment_name: str = None,
8094
tags: List[Dict[str, str]] = None,
8195
) -> Dict[str, Any]:
8296
"""Creates a Pipeline in the Pipelines service.
8397
8498
Args:
8599
role_arn (str): The role arn that is assumed by the pipeline to create step artifacts.
86100
description (str): A description of the pipeline.
87-
experiment_name (str): The name of the experiment.
88101
tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as
89102
tags.
90103
@@ -96,7 +109,6 @@ def create(
96109
kwargs = self._create_args(role_arn, description)
97110
update_args(
98111
kwargs,
99-
ExperimentName=experiment_name,
100112
Tags=tags,
101113
)
102114
return self.sagemaker_session.sagemaker_client.create_pipeline(**kwargs)
@@ -106,7 +118,7 @@ def _create_args(self, role_arn: str, description: str):
106118
107119
Args:
108120
role_arn (str): The role arn that is assumed by pipelines to create step artifacts.
109-
pipeline_description (str): A description of the pipeline.
121+
description (str): A description of the pipeline.
110122
111123
Returns:
112124
A keyword argument dict for calling create_pipeline.
@@ -147,23 +159,21 @@ def upsert(
147159
self,
148160
role_arn: str,
149161
description: str = None,
150-
experiment_name: str = None,
151162
tags: List[Dict[str, str]] = None,
152163
) -> Dict[str, Any]:
153164
"""Creates a pipeline or updates it, if it already exists.
154165
155166
Args:
156167
role_arn (str): The role arn that is assumed by workflow to create step artifacts.
157-
pipeline_description (str): A description of the pipeline.
158-
experiment_name (str): The name of the experiment.
168+
description (str): A description of the pipeline.
159169
tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as
160170
tags.
161171
162172
Returns:
163173
response dict from service
164174
"""
165175
try:
166-
response = self.create(role_arn, description, experiment_name, tags)
176+
response = self.create(role_arn, description, tags)
167177
except ClientError as e:
168178
error = e.response["Error"]
169179
if (
@@ -224,6 +234,9 @@ def start(
224234
def definition(self) -> str:
225235
"""Converts a request structure to string representation for workflow service calls."""
226236
request_dict = self.to_request()
237+
request_dict["PipelineExperimentConfig"] = interpolate(
238+
request_dict["PipelineExperimentConfig"]
239+
)
227240
request_dict["Steps"] = interpolate(request_dict["Steps"])
228241

229242
return json.dumps(request_dict)
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Pipeline experiment config for SageMaker pipeline."""
14+
from __future__ import absolute_import
15+
16+
from typing import Union
17+
18+
from sagemaker.workflow.parameters import Parameter
19+
from sagemaker.workflow.execution_variables import ExecutionVariable
20+
from sagemaker.workflow.entities import (
21+
Entity,
22+
Expression,
23+
RequestType,
24+
)
25+
26+
27+
class PipelineExperimentConfig(Entity):
28+
"""Experiment config for SageMaker pipeline."""
29+
30+
def __init__(
31+
self,
32+
experiment_name: Union[str, Parameter, ExecutionVariable, Expression],
33+
trial_name: Union[str, Parameter, ExecutionVariable, Expression],
34+
):
35+
"""Create a PipelineExperimentConfig
36+
37+
Args:
38+
experiment_name: the name of the experiment that will be created
39+
trial_name: the name of the trial that will be created
40+
"""
41+
self.experiment_name = experiment_name
42+
self.trial_name = trial_name
43+
44+
def to_request(self) -> RequestType:
45+
"""Returns: the request structure."""
46+
47+
return {
48+
"ExperimentName": self.experiment_name,
49+
"TrialName": self.trial_name,
50+
}
51+
52+
53+
class PipelineExperimentConfigProperty(Expression):
54+
"""Reference to pipeline experiment config property."""
55+
56+
def __init__(self, name: str):
57+
"""Create a reference to pipeline experiment property.
58+
59+
Args:
60+
name (str): The name of the pipeline experiment config property.
61+
"""
62+
super(PipelineExperimentConfigProperty, self).__init__()
63+
self.name = name
64+
65+
@property
66+
def expr(self) -> RequestType:
67+
"""The 'Get' expression dict for a pipeline experiment config property."""
68+
69+
return {"Get": f"PipelineExperimentConfig.{self.name}"}
70+
71+
72+
class PipelineExperimentConfigProperties:
73+
"""Enum-like class for all pipeline experiment config property references."""
74+
75+
EXPERIMENT_NAME = PipelineExperimentConfigProperty("ExperimentName")
76+
TRIAL_NAME = PipelineExperimentConfigProperty("TrialName")

0 commit comments

Comments
 (0)