Skip to content

Commit 850972c

Browse files
qidewenwhenDewen Qi
and
Dewen Qi
authored
change: Implement override solution for pipeline variables (#2995)
Co-authored-by: Dewen Qi <[email protected]>
1 parent 98bbea6 commit 850972c

17 files changed

+777
-88
lines changed

src/sagemaker/estimator.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,7 @@
7474
get_config_value,
7575
name_from_base,
7676
)
77-
from sagemaker.workflow.entities import Expression
78-
from sagemaker.workflow.parameters import Parameter
79-
from sagemaker.workflow.properties import Properties
77+
from sagemaker.workflow.entities import PipelineVariable
8078

8179
logger = logging.getLogger(__name__)
8280

@@ -602,7 +600,7 @@ def _json_encode_hyperparameters(hyperparameters: Dict[str, Any]) -> Dict[str, A
602600
current_hyperparameters = hyperparameters
603601
if current_hyperparameters is not None:
604602
hyperparameters = {
605-
str(k): (v if isinstance(v, (Parameter, Expression, Properties)) else json.dumps(v))
603+
str(k): (v.to_string() if isinstance(v, PipelineVariable) else json.dumps(v))
606604
for (k, v) in current_hyperparameters.items()
607605
}
608606
return hyperparameters
@@ -1813,7 +1811,7 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
18131811
current_hyperparameters = estimator.hyperparameters()
18141812
if current_hyperparameters is not None:
18151813
hyperparameters = {
1816-
str(k): (v if isinstance(v, (Parameter, Expression, Properties)) else str(v))
1814+
str(k): (v.to_string() if isinstance(v, PipelineVariable) else str(v))
18171815
for (k, v) in current_hyperparameters.items()
18181816
}
18191817

src/sagemaker/parameter.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@
1414
from __future__ import absolute_import
1515

1616
import json
17-
from sagemaker.workflow.parameters import Parameter as PipelineParameter
18-
from sagemaker.workflow.functions import JsonGet as PipelineJsonGet
19-
from sagemaker.workflow.functions import Join as PipelineJoin
17+
18+
from sagemaker.workflow.entities import PipelineVariable
2019

2120

2221
class ParameterRange(object):
@@ -73,11 +72,11 @@ def as_tuning_range(self, name):
7372
return {
7473
"Name": name,
7574
"MinValue": str(self.min_value)
76-
if not isinstance(self.min_value, (PipelineParameter, PipelineJsonGet, PipelineJoin))
77-
else self.min_value,
75+
if not isinstance(self.min_value, PipelineVariable)
76+
else self.min_value.to_string(),
7877
"MaxValue": str(self.max_value)
79-
if not isinstance(self.max_value, (PipelineParameter, PipelineJsonGet, PipelineJoin))
80-
else self.max_value,
78+
if not isinstance(self.max_value, PipelineVariable)
79+
else self.max_value.to_string(),
8180
"ScalingType": self.scaling_type,
8281
}
8382

@@ -112,8 +111,7 @@ def __init__(self, values): # pylint: disable=super-init-not-called
112111
"""
113112
values = values if isinstance(values, list) else [values]
114113
self.values = [
115-
str(v) if not isinstance(v, (PipelineParameter, PipelineJsonGet, PipelineJoin)) else v
116-
for v in values
114+
str(v) if not isinstance(v, PipelineVariable) else v.to_string() for v in values
117115
]
118116

119117
def as_tuning_range(self, name):

src/sagemaker/tuner.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
IntegerParameter,
3939
ParameterRange,
4040
)
41+
from sagemaker.workflow.entities import PipelineVariable
4142
from sagemaker.workflow.parameters import Parameter as PipelineParameter
4243
from sagemaker.workflow.functions import JsonGet as PipelineJsonGet
4344
from sagemaker.workflow.functions import Join as PipelineJoin
@@ -376,9 +377,7 @@ def _prepare_static_hyperparameters(
376377
"""Prepare static hyperparameters for one estimator before tuning."""
377378
# Remove any hyperparameter that will be tuned
378379
static_hyperparameters = {
379-
str(k): str(v)
380-
if not isinstance(v, (PipelineParameter, PipelineJsonGet, PipelineJoin))
381-
else v
380+
str(k): str(v) if not isinstance(v, PipelineVariable) else v.to_string()
382381
for (k, v) in estimator.hyperparameters().items()
383382
}
384383
for hyperparameter_name in hyperparameter_ranges.keys():

src/sagemaker/workflow/entities.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import abc
1717

1818
from enum import EnumMeta
19-
from typing import Any, Dict, List, Union
19+
from typing import Any, Dict, List, Union, Optional
2020

2121
PrimitiveType = Union[str, int, bool, float, None]
2222
RequestType = Union[Dict[str, Any], List[Dict[str, Any]]]
@@ -57,3 +57,80 @@ class Expression(abc.ABC):
5757
@abc.abstractmethod
5858
def expr(self) -> RequestType:
5959
"""Get the expression structure for workflow service calls."""
60+
61+
62+
class PipelineVariable(Expression):
63+
"""Base object for pipeline variables
64+
65+
PipelineVariables must implement the expr property.
66+
"""
67+
68+
def __add__(self, other: Union[Expression, PrimitiveType]):
69+
"""Add function for PipelineVariable
70+
71+
Args:
72+
other (Union[Expression, PrimitiveType]): The other object to be concatenated.
73+
74+
Always raise an error since pipeline variables do not support concatenation
75+
"""
76+
77+
raise TypeError("Pipeline variables do not support concatenation.")
78+
79+
def __str__(self):
80+
"""Override built-in String function for PipelineVariable"""
81+
raise TypeError("Pipeline variables do not support __str__ operation.")
82+
83+
def __int__(self):
84+
"""Override built-in Integer function for PipelineVariable"""
85+
raise TypeError("Pipeline variables do not support __int__ operation.")
86+
87+
def __float__(self):
88+
"""Override built-in Float function for PipelineVariable"""
89+
raise TypeError("Pipeline variables do not support __float__ operation.")
90+
91+
def to_string(self):
92+
"""Prompt the pipeline to convert the pipeline variable to String in runtime"""
93+
from sagemaker.workflow.functions import Join
94+
95+
return Join(on="", values=[self])
96+
97+
@property
98+
@abc.abstractmethod
99+
def expr(self) -> RequestType:
100+
"""Get the expression structure for workflow service calls."""
101+
102+
def startswith(
103+
self,
104+
prefix: Union[str, tuple], # pylint: disable=unused-argument
105+
start: Optional[int] = None, # pylint: disable=unused-argument
106+
end: Optional[int] = None, # pylint: disable=unused-argument
107+
) -> bool:
108+
"""Simulate the Python string's built-in method: startswith
109+
110+
Args:
111+
prefix (str, tuple): The (tuple of) string to be checked.
112+
start (int): To set the start index of the matching boundary (default: None).
113+
end (int): To set the end index of the matching boundary (default: None).
114+
115+
Return:
116+
bool: Always return False as Pipeline variables are parsed during execution runtime
117+
"""
118+
return False
119+
120+
def endswith(
121+
self,
122+
suffix: Union[str, tuple], # pylint: disable=unused-argument
123+
start: Optional[int] = None, # pylint: disable=unused-argument
124+
end: Optional[int] = None, # pylint: disable=unused-argument
125+
) -> bool:
126+
"""Simulate the Python string's built-in method: endswith
127+
128+
Args:
129+
suffix (str, tuple): The (tuple of) string to be checked.
130+
start (int): To set the start index of the matching boundary (default: None).
131+
end (int): To set the end index of the matching boundary (default: None).
132+
133+
Return:
134+
bool: Always return False as Pipeline variables are parsed during execution runtime
135+
"""
136+
return False

src/sagemaker/workflow/execution_variables.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.workflow.entities import (
17-
Expression,
1817
RequestType,
18+
PipelineVariable,
1919
)
2020

2121

22-
class ExecutionVariable(Expression):
22+
class ExecutionVariable(PipelineVariable):
2323
"""Pipeline execution variables for workflow."""
2424

2525
def __init__(self, name: str):
@@ -30,6 +30,13 @@ def __init__(self, name: str):
3030
"""
3131
self.name = name
3232

33+
def to_string(self) -> PipelineVariable:
34+
"""Prompt the pipeline to convert the pipeline variable to String in runtime
35+
36+
As ExecutionVariable is treated as String in runtime, no extra actions are needed.
37+
"""
38+
return self
39+
3340
@property
3441
def expr(self) -> RequestType:
3542
"""The 'Get' expression dict for an `ExecutionVariable`."""

src/sagemaker/workflow/functions.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@
1717

1818
import attr
1919

20-
from sagemaker.workflow.entities import Expression
20+
from sagemaker.workflow.entities import PipelineVariable
2121
from sagemaker.workflow.properties import PropertyFile
2222

2323

2424
@attr.s
25-
class Join(Expression):
25+
class Join(PipelineVariable):
2626
"""Join together properties.
2727
2828
Examples:
@@ -38,15 +38,23 @@ class Join(Expression):
3838
Attributes:
3939
values (List[Union[PrimitiveType, Parameter, Expression]]):
4040
The primitive type values, parameters, step properties, expressions to join.
41-
on_str (str): The string to join the values on (Defaults to "").
41+
on (str): The string to join the values on (Defaults to "").
4242
"""
4343

4444
on: str = attr.ib(factory=str)
4545
values: List = attr.ib(factory=list)
4646

47+
def to_string(self) -> PipelineVariable:
48+
"""Prompt the pipeline to convert the pipeline variable to String in runtime
49+
50+
As Join is treated as String in runtime, no extra actions are needed.
51+
"""
52+
return self
53+
4754
@property
4855
def expr(self):
4956
"""The expression dict for a `Join` function."""
57+
5058
return {
5159
"Std:Join": {
5260
"On": self.on,
@@ -58,7 +66,7 @@ def expr(self):
5866

5967

6068
@attr.s
61-
class JsonGet(Expression):
69+
class JsonGet(PipelineVariable):
6270
"""Get JSON properties from PropertyFiles.
6371
6472
Attributes:

src/sagemaker/workflow/parameters.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
Entity,
2525
PrimitiveType,
2626
RequestType,
27+
PipelineVariable,
2728
)
2829

2930

@@ -48,7 +49,7 @@ def python_type(self) -> Type:
4849

4950

5051
@attr.s
51-
class Parameter(Entity):
52+
class Parameter(PipelineVariable, Entity):
5253
"""Pipeline parameter for workflow.
5354
5455
Attributes:
@@ -170,6 +171,13 @@ def __hash__(self):
170171
"""Hash function for parameter types"""
171172
return hash(tuple(self.to_request()))
172173

174+
def to_string(self) -> PipelineVariable:
175+
"""Prompt the pipeline to convert the pipeline variable to String in runtime
176+
177+
As ParameterString is treated as String in runtime, no extra actions are needed.
178+
"""
179+
return self
180+
173181
def to_request(self) -> RequestType:
174182
"""Get the request structure for workflow service calls."""
175183
request_dict = super(ParameterString, self).to_request()

src/sagemaker/workflow/properties.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,17 @@
1313
"""The properties definitions for workflow."""
1414
from __future__ import absolute_import
1515

16+
from abc import ABCMeta
1617
from typing import Dict, Union, List
1718

1819
import attr
1920

2021
import botocore.loaders
2122

22-
from sagemaker.workflow.entities import Expression
23+
from sagemaker.workflow.entities import Expression, PipelineVariable
2324

2425

25-
class PropertiesMeta(type):
26+
class PropertiesMeta(ABCMeta):
2627
"""Load an internal shapes attribute from the botocore service model
2728
2829
for sagemaker and emr service.
@@ -44,7 +45,7 @@ def __new__(mcs, *args, **kwargs):
4445
return super().__new__(mcs, *args, **kwargs)
4546

4647

47-
class Properties(metaclass=PropertiesMeta):
48+
class Properties(PipelineVariable, metaclass=PropertiesMeta):
4849
"""Properties for use in workflow expressions."""
4950

5051
def __init__(

0 commit comments

Comments
 (0)