Skip to content

Commit 73e21a0

Browse files
committed
feature: add pipeline experiment config
1 parent 5333974 commit 73e21a0

File tree

7 files changed

+452
-50
lines changed

7 files changed

+452
-50
lines changed

src/sagemaker/workflow/execution_variables.py

Lines changed: 4 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,56 +13,27 @@
1313
"""Pipeline parameters and conditions for workflow."""
1414
from __future__ import absolute_import
1515

16-
from typing import Dict
17-
1816
from sagemaker.workflow.entities import (
19-
Entity,
17+
Expression,
2018
RequestType,
2119
)
2220

2321

24-
class ExecutionVariable(Entity, str):
22+
class ExecutionVariable(Expression):
2523
"""Pipeline execution variables for workflow."""
2624

27-
def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument
28-
"""Subclass str"""
29-
value = ""
30-
if len(args) == 1:
31-
value = args[0] or value
32-
elif kwargs:
33-
value = kwargs.get("name", value)
34-
return str.__new__(cls, ExecutionVariable._expr(value))
35-
3625
def __init__(self, name: str):
3726
"""Create a pipeline execution variable.
3827
3928
Args:
4029
name (str): The name of the execution variable.
4130
"""
42-
super(ExecutionVariable, self).__init__()
4331
self.name = name
4432

45-
def __hash__(self):
46-
"""Hash function for execution variable types"""
47-
return hash(tuple(self.to_request()))
48-
49-
def to_request(self) -> RequestType:
50-
"""Get the request structure for workflow service calls."""
51-
return self.expr
52-
5333
@property
54-
def expr(self) -> Dict[str, str]:
34+
def expr(self) -> RequestType:
5535
"""The 'Get' expression dict for an `ExecutionVariable`."""
56-
return ExecutionVariable._expr(self.name)
57-
58-
@classmethod
59-
def _expr(cls, name):
60-
"""An internal classmethod for the 'Get' expression dict for an `ExecutionVariable`.
61-
62-
Args:
63-
name (str): The name of the execution variable.
64-
"""
65-
return {"Get": f"Execution.{name}"}
36+
return {"Get": f"Execution.{self.name}"}
6637

6738

6839
class ExecutionVariables:

src/sagemaker/workflow/pipeline.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import json
1717

1818
from copy import deepcopy
19-
from typing import Any, Dict, List, Sequence, Union
19+
from typing import Any, Dict, List, Sequence, Union, Optional
2020

2121
import attr
2222
import botocore
@@ -30,7 +30,9 @@
3030
Expression,
3131
RequestType,
3232
)
33+
from sagemaker.workflow.execution_variables import ExecutionVariables
3334
from sagemaker.workflow.parameters import Parameter
35+
from sagemaker.workflow.pipeline_experiment_config import PipelineExperimentConfig
3436
from sagemaker.workflow.properties import Properties
3537
from sagemaker.workflow.steps import Step
3638
from sagemaker.workflow.step_collections import StepCollection
@@ -44,6 +46,11 @@ class Pipeline(Entity):
4446
Attributes:
4547
name (str): The name of the pipeline.
4648
parameters (Sequence[Parameters]): The list of the parameters.
49+
pipeline_experiment_config (Optional[PipelineExperimentConfig]): If set,
50+
the workflow will attempt to create an experiment and trial before
51+
executing the steps. Creation will be skipped if an experiment or a trial with
52+
the same name already exists.
53+
If set to None, no experiment or trial will be created automatically.
4754
steps (Sequence[Steps]): The list of the non-conditional steps associated with the pipeline.
4855
Any steps that are within the
4956
`if_steps` or `else_steps` of a `ConditionStep` cannot be listed in the steps of a
@@ -57,6 +64,11 @@ class Pipeline(Entity):
5764

5865
name: str = attr.ib(factory=str)
5966
parameters: Sequence[Parameter] = attr.ib(factory=list)
67+
pipeline_experiment_config: Optional[PipelineExperimentConfig] = attr.ib(
68+
default=PipelineExperimentConfig(
69+
ExecutionVariables.PIPELINE_NAME, ExecutionVariables.PIPELINE_EXECUTION_ID
70+
)
71+
)
6072
steps: Sequence[Union[Step, StepCollection]] = attr.ib(factory=list)
6173
sagemaker_session: Session = attr.ib(factory=Session)
6274

@@ -69,22 +81,22 @@ def to_request(self) -> RequestType:
6981
"Version": self._version,
7082
"Metadata": self._metadata,
7183
"Parameters": list_to_request(self.parameters),
84+
"PipelineExperimentConfig":
85+
self.pipeline_experiment_config.to_request if self.pipeline_experiment_config is not None else None,
7286
"Steps": list_to_request(self.steps),
7387
}
7488

7589
def create(
7690
self,
7791
role_arn: str,
7892
description: str = None,
79-
experiment_name: str = None,
8093
tags: List[Dict[str, str]] = None,
8194
) -> Dict[str, Any]:
8295
"""Creates a Pipeline in the Pipelines service.
8396
8497
Args:
8598
role_arn (str): The role arn that is assumed by the pipeline to create step artifacts.
8699
description (str): A description of the pipeline.
87-
experiment_name (str): The name of the experiment.
88100
tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as
89101
tags.
90102
@@ -96,7 +108,6 @@ def create(
96108
kwargs = self._create_args(role_arn, description)
97109
update_args(
98110
kwargs,
99-
ExperimentName=experiment_name,
100111
Tags=tags,
101112
)
102113
return self.sagemaker_session.sagemaker_client.create_pipeline(**kwargs)
@@ -106,7 +117,7 @@ def _create_args(self, role_arn: str, description: str):
106117
107118
Args:
108119
role_arn (str): The role arn that is assumed by pipelines to create step artifacts.
109-
pipeline_description (str): A description of the pipeline.
120+
description (str): A description of the pipeline.
110121
111122
Returns:
112123
A keyword argument dict for calling create_pipeline.
@@ -147,23 +158,21 @@ def upsert(
147158
self,
148159
role_arn: str,
149160
description: str = None,
150-
experiment_name: str = None,
151161
tags: List[Dict[str, str]] = None,
152162
) -> Dict[str, Any]:
153163
"""Creates a pipeline or updates it, if it already exists.
154164
155165
Args:
156166
role_arn (str): The role arn that is assumed by workflow to create step artifacts.
157-
pipeline_description (str): A description of the pipeline.
158-
experiment_name (str): The name of the experiment.
167+
description (str): A description of the pipeline.
159168
tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as
160169
tags.
161170
162171
Returns:
163172
response dict from service
164173
"""
165174
try:
166-
response = self.create(role_arn, description, experiment_name, tags)
175+
response = self.create(role_arn, description, tags)
167176
except ClientError as e:
168177
error = e.response["Error"]
169178
if (
@@ -224,6 +233,7 @@ def start(
224233
def definition(self) -> str:
    """Serialize the pipeline's request structure to a JSON string.

    The "PipelineExperimentConfig" and "Steps" entries of the request
    structure are each passed through ``interpolate`` (in that order)
    before the whole structure is dumped.

    Returns:
        str: The ``json.dumps`` rendering of the request dict.
    """
    payload = self.to_request()
    # Same keys, same order as the individual assignments they replace.
    for key in ("PipelineExperimentConfig", "Steps"):
        payload[key] = interpolate(payload[key])
    return json.dumps(payload)
src/sagemaker/workflow/pipeline_experiment_config.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Pipeline experiment config for SageMaker pipeline."""
14+
from __future__ import absolute_import
15+
16+
from typing import Union
17+
18+
from sagemaker.workflow.parameters import Parameter
19+
from sagemaker.workflow.execution_variables import ExecutionVariable
20+
from sagemaker.workflow.entities import (
21+
Entity,
22+
Expression,
23+
RequestType,
24+
)
25+
26+
27+
class PipelineExperimentConfig(Entity):
    """Experiment config for SageMaker pipeline.

    Holds the experiment name and trial name and renders them as a
    request structure.
    """

    def __init__(
        self,
        experiment_name: Union[str, Parameter, ExecutionVariable, Expression],
        trial_name: Union[str, Parameter, ExecutionVariable, Expression],
    ):
        """Create a PipelineExperimentConfig.

        Both values are stored as given; no validation or conversion is
        performed here.

        Args:
            experiment_name: The name of the experiment that will be created.
            trial_name: The name of the trial that will be created.
        """
        self.trial_name = trial_name
        self.experiment_name = experiment_name

    @property
    def to_request(self) -> RequestType:
        """Build the request structure for this config.

        NOTE: this is exposed as a property, so it is read as
        ``config.to_request`` (without calling it).

        Returns:
            A dict mapping "ExperimentName" and "TrialName" to the stored
            experiment and trial names.
        """
        return dict(
            ExperimentName=self.experiment_name,
            TrialName=self.trial_name,
        )
54+
55+
56+
class PipelineExperimentConfigProperty(Expression):
    """Reference to a single property of the pipeline experiment config."""

    def __init__(self, name: str):
        """Create a reference to a pipeline experiment config property.

        Args:
            name (str): The name of the pipeline experiment config property.
        """
        super().__init__()
        self.name = name

    @property
    def expr(self) -> RequestType:
        """The 'Get' expression dict for a pipeline experiment config property.

        Returns:
            A one-entry dict whose "Get" value is
            ``PipelineExperimentConfig.<name>``.
        """
        reference = f"PipelineExperimentConfig.{self.name}"
        return {"Get": reference}
73+
74+
75+
class PipelineExperimentConfigProperties:
    """Enum-like class for all pipeline experiment config property references.

    Each attribute is a ready-made `PipelineExperimentConfigProperty`
    instance created once at class-definition time.
    """

    # Reference to the "ExperimentName" property.
    EXPERIMENT_NAME = PipelineExperimentConfigProperty("ExperimentName")
    # Reference to the "TrialName" property.
    TRIAL_NAME = PipelineExperimentConfigProperty("TrialName")

0 commit comments

Comments
 (0)