Skip to content

feature: add pipeline experiment config #2331

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 4 additions & 33 deletions src/sagemaker/workflow/execution_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,56 +13,27 @@
"""Pipeline parameters and conditions for workflow."""
from __future__ import absolute_import

from typing import Dict

from sagemaker.workflow.entities import (
Entity,
Expression,
RequestType,
)


class ExecutionVariable(Entity, str):
class ExecutionVariable(Expression):
"""Pipeline execution variables for workflow."""

def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument
"""Subclass str"""
value = ""
if len(args) == 1:
value = args[0] or value
elif kwargs:
value = kwargs.get("name", value)
return str.__new__(cls, ExecutionVariable._expr(value))

def __init__(self, name: str):
"""Create a pipeline execution variable.

Args:
name (str): The name of the execution variable.
"""
super(ExecutionVariable, self).__init__()
self.name = name

def __hash__(self):
"""Hash function for execution variable types"""
return hash(tuple(self.to_request()))

def to_request(self) -> RequestType:
"""Get the request structure for workflow service calls."""
return self.expr

@property
def expr(self) -> Dict[str, str]:
def expr(self) -> RequestType:
"""The 'Get' expression dict for an `ExecutionVariable`."""
return ExecutionVariable._expr(self.name)

@classmethod
def _expr(cls, name):
"""An internal classmethod for the 'Get' expression dict for an `ExecutionVariable`.

Args:
name (str): The name of the execution variable.
"""
return {"Get": f"Execution.{name}"}
return {"Get": f"Execution.{self.name}"}


class ExecutionVariables:
Expand Down
32 changes: 23 additions & 9 deletions src/sagemaker/workflow/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import json

from copy import deepcopy
from typing import Any, Dict, List, Sequence, Union
from typing import Any, Dict, List, Sequence, Union, Optional

import attr
import botocore
Expand All @@ -30,7 +30,9 @@
Expression,
RequestType,
)
from sagemaker.workflow.execution_variables import ExecutionVariables
from sagemaker.workflow.parameters import Parameter
from sagemaker.workflow.pipeline_experiment_config import PipelineExperimentConfig
from sagemaker.workflow.properties import Properties
from sagemaker.workflow.steps import Step
from sagemaker.workflow.step_collections import StepCollection
Expand All @@ -44,6 +46,12 @@ class Pipeline(Entity):
Attributes:
name (str): The name of the pipeline.
parameters (Sequence[Parameters]): The list of the parameters.
pipeline_experiment_config (Optional[PipelineExperimentConfig]): If set,
the workflow will attempt to create an experiment and trial before
executing the steps. Creation will be skipped if an experiment or a trial with
the same name already exists. By default, pipeline name is used as
Comment on lines +51 to +52
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could the same trial potentially be used for multiple executions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. If users override the default config in that way.

experiment name and execution id is used as the trial name.
If set to None, no experiment or trial will be created automatically.
steps (Sequence[Steps]): The list of the non-conditional steps associated with the pipeline.
Any steps that are within the
`if_steps` or `else_steps` of a `ConditionStep` cannot be listed in the steps of a
Expand All @@ -57,6 +65,11 @@ class Pipeline(Entity):

name: str = attr.ib(factory=str)
parameters: Sequence[Parameter] = attr.ib(factory=list)
pipeline_experiment_config: Optional[PipelineExperimentConfig] = attr.ib(
default=PipelineExperimentConfig(
ExecutionVariables.PIPELINE_NAME, ExecutionVariables.PIPELINE_EXECUTION_ID
)
)
steps: Sequence[Union[Step, StepCollection]] = attr.ib(factory=list)
sagemaker_session: Session = attr.ib(factory=Session)

Expand All @@ -69,22 +82,23 @@ def to_request(self) -> RequestType:
"Version": self._version,
"Metadata": self._metadata,
"Parameters": list_to_request(self.parameters),
"PipelineExperimentConfig": self.pipeline_experiment_config.to_request()
if self.pipeline_experiment_config is not None
else None,
"Steps": list_to_request(self.steps),
}

def create(
self,
role_arn: str,
description: str = None,
experiment_name: str = None,
tags: List[Dict[str, str]] = None,
) -> Dict[str, Any]:
"""Creates a Pipeline in the Pipelines service.

Args:
role_arn (str): The role arn that is assumed by the pipeline to create step artifacts.
description (str): A description of the pipeline.
experiment_name (str): The name of the experiment.
tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as
tags.

Expand All @@ -96,7 +110,6 @@ def create(
kwargs = self._create_args(role_arn, description)
update_args(
kwargs,
ExperimentName=experiment_name,
Tags=tags,
)
return self.sagemaker_session.sagemaker_client.create_pipeline(**kwargs)
Expand All @@ -106,7 +119,7 @@ def _create_args(self, role_arn: str, description: str):

Args:
role_arn (str): The role arn that is assumed by pipelines to create step artifacts.
pipeline_description (str): A description of the pipeline.
description (str): A description of the pipeline.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit - looks like this is just a doc change but ideally doc would be checked statically to avoid this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do run docstyle and pylint, and both passed despite that error. Either something is misconfigured, or docstyle doesn't check documented parameter names at all.


Returns:
A keyword argument dict for calling create_pipeline.
Expand Down Expand Up @@ -147,23 +160,21 @@ def upsert(
self,
role_arn: str,
description: str = None,
experiment_name: str = None,
tags: List[Dict[str, str]] = None,
) -> Dict[str, Any]:
"""Creates a pipeline or updates it, if it already exists.

Args:
role_arn (str): The role arn that is assumed by workflow to create step artifacts.
pipeline_description (str): A description of the pipeline.
experiment_name (str): The name of the experiment.
description (str): A description of the pipeline.
tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as
tags.

Returns:
response dict from service
"""
try:
response = self.create(role_arn, description, experiment_name, tags)
response = self.create(role_arn, description, tags)
except ClientError as e:
error = e.response["Error"]
if (
Expand Down Expand Up @@ -224,6 +235,9 @@ def start(
def definition(self) -> str:
"""Converts a request structure to string representation for workflow service calls."""
request_dict = self.to_request()
request_dict["PipelineExperimentConfig"] = interpolate(
request_dict["PipelineExperimentConfig"]
)
request_dict["Steps"] = interpolate(request_dict["Steps"])

return json.dumps(request_dict)
Expand Down
76 changes: 76 additions & 0 deletions src/sagemaker/workflow/pipeline_experiment_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Pipeline experiment config for SageMaker pipeline."""
from __future__ import absolute_import

from typing import Union

from sagemaker.workflow.parameters import Parameter
from sagemaker.workflow.execution_variables import ExecutionVariable
from sagemaker.workflow.entities import (
Entity,
Expression,
RequestType,
)


class PipelineExperimentConfig(Entity):
    """Experiment and trial naming configuration for a SageMaker pipeline."""

    def __init__(
        self,
        experiment_name: Union[str, Parameter, ExecutionVariable, Expression],
        trial_name: Union[str, Parameter, ExecutionVariable, Expression],
    ):
        """Store the experiment and trial names used by the pipeline.

        Args:
            experiment_name: The name of the experiment that will be created.
                May be a plain string, a pipeline parameter, an execution
                variable, or another expression.
            trial_name: The name of the trial that will be created. Accepts
                the same kinds of values as `experiment_name`.
        """
        self.experiment_name, self.trial_name = experiment_name, trial_name

    def to_request(self) -> RequestType:
        """Build the request structure for this config.

        Returns:
            A dict with the keys "ExperimentName" and "TrialName" holding the
            stored values as they were given to the constructor.
        """
        request = {}
        request["ExperimentName"] = self.experiment_name
        request["TrialName"] = self.trial_name
        return request


class PipelineExperimentConfigProperty(Expression):
    """Expression that refers to one property of the pipeline experiment config."""

    def __init__(self, name: str):
        """Create a reference to a pipeline experiment config property.

        Args:
            name (str): The name of the pipeline experiment config property.
        """
        super().__init__()
        self.name = name

    @property
    def expr(self) -> RequestType:
        """The 'Get' expression dict for a pipeline experiment config property."""
        # The property name is appended to the fixed "PipelineExperimentConfig." prefix.
        reference = f"PipelineExperimentConfig.{self.name}"
        return {"Get": reference}


class PipelineExperimentConfigProperties:
    """Namespace of predefined references to pipeline experiment config properties."""

    # Reference to the "ExperimentName" property.
    EXPERIMENT_NAME = PipelineExperimentConfigProperty(name="ExperimentName")
    # Reference to the "TrialName" property.
    TRIAL_NAME = PipelineExperimentConfigProperty(name="TrialName")
Loading