Skip to content

Commit c9e517c

Browse files
Dewen Qi (qidewenwhen)
Dewen Qi
authored and committed
change: turn off PipelineVariable inheritance from python primitives
1 parent 19b5560 commit c9e517c

17 files changed

+83
-198
lines changed

src/sagemaker/estimator.py

+1
Original file line number | Diff line number | Diff line change
@@ -483,6 +483,7 @@ def __init__(
483483
if (
484484
not self.sagemaker_session.local_mode
485485
and output_path
486+
and not is_pipeline_variable(output_path)
486487
and output_path.startswith("file://")
487488
):
488489
raise RuntimeError("file:// output paths are only supported in Local Mode")

src/sagemaker/fw_utils.py

+8
Original file line number | Diff line number | Diff line change
@@ -428,6 +428,10 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
428428
"""
429429
if training_instance_type == "local" or distribution is None:
430430
return
431+
if is_pipeline_variable(training_instance_type):
432+
# The training_instance_type is not available in compile time.
433+
# Rather, it's given in Pipeline execution time
434+
return
431435

432436
is_multi_gpu_instance = (
433437
training_instance_type == "local_gpu"
@@ -485,6 +489,10 @@ def validate_smdistributed(
485489
if "smdistributed" not in distribution:
486490
# Distribution strategy other than smdistributed is selected
487491
return
492+
if is_pipeline_variable(instance_type):
493+
# The instance_type is not available in compile time.
494+
# Rather, it's given in Pipeline execution time
495+
return
488496

489497
# distribution contains smdistributed
490498
smdistributed = distribution["smdistributed"]

src/sagemaker/workflow/__init__.py

-6
Original file line number | Diff line number | Diff line change
@@ -12,14 +12,8 @@
1212
# language governing permissions and limitations under the License.
1313
"""Defines Types etc. used in workflow."""
1414
from __future__ import absolute_import
15-
from typing import Union
1615

1716
from sagemaker.workflow.entities import Expression
18-
from sagemaker.workflow.execution_variables import ExecutionVariable
19-
from sagemaker.workflow.parameters import Parameter
20-
from sagemaker.workflow.properties import Properties
21-
22-
PipelineNonPrimitiveInputTypes = Union[ExecutionVariable, Expression, Parameter, Properties]
2317

2418

2519
def is_pipeline_variable(var: object) -> bool:

src/sagemaker/workflow/clarify_check_step.py

+14 -14
Original file line number | Diff line number | Diff line change
@@ -37,8 +37,8 @@
3737
from sagemaker.model_monitor.model_monitoring import _MODEL_MONITOR_S3_PATH
3838
from sagemaker.processing import ProcessingInput, ProcessingOutput, ProcessingJob
3939
from sagemaker.utils import name_from_base
40-
from sagemaker.workflow import PipelineNonPrimitiveInputTypes, is_pipeline_variable
41-
from sagemaker.workflow.entities import RequestType
40+
from sagemaker.workflow import is_pipeline_variable
41+
from sagemaker.workflow.entities import RequestType, PipelineVariable
4242
from sagemaker.workflow.properties import Properties
4343
from sagemaker.workflow.step_collections import StepCollection
4444
from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig
@@ -59,7 +59,7 @@ class ClarifyCheckConfig(ABC):
5959
data_config (DataConfig): Config of the input/output data.
6060
kms_key (str): The ARN of the KMS key that is used to encrypt the
6161
user code file (default: None).
62-
This field CANNOT be any of PipelineNonPrimitiveInputTypes.
62+
This field CANNOT be any type of the `PipelineVariable`.
6363
monitoring_analysis_config_uri: (str): The uri of monitoring analysis config.
6464
This field does not take input.
6565
It will be generated once uploading the created analysis config file.
@@ -86,7 +86,7 @@ class DataBiasCheckConfig(ClarifyCheckConfig):
8686
"`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kolmogorov-smirnov.html>`_",
8787
"`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-cddl.html>`_"].
8888
Defaults to computing all.
89-
This field CANNOT be any of PipelineNonPrimitiveInputTypes.
89+
This field CANNOT be any type of the `PipelineVariable`.
9090
""" # noqa E501
9191

9292
data_bias_config: BiasConfig = attr.ib()
@@ -115,7 +115,7 @@ class ModelBiasCheckConfig(ClarifyCheckConfig):
115115
", "`TE <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-te.html>`_",
116116
"`FT <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ft.html>`_"].
117117
Defaults to computing all.
118-
This field CANNOT be any of PipelineNonPrimitiveInputTypes.
118+
This field CANNOT be any type of the `PipelineVariable`.
119119
"""
120120

121121
data_bias_config: BiasConfig = attr.ib()
@@ -136,7 +136,7 @@ class ModelExplainabilityCheckConfig(ClarifyCheckConfig):
136136
in the model output for the predicted scores to be explained (default: None).
137137
This is not required if the model output is a single score. Alternatively,
138138
an instance of ModelPredictedLabelConfig can be provided
139-
but this field CANNOT be any of PipelineNonPrimitiveInputTypes.
139+
but this field CANNOT be any type of the `PipelineVariable`.
140140
"""
141141

142142
model_config: ModelConfig = attr.ib()
@@ -152,10 +152,10 @@ def __init__(
152152
name: str,
153153
clarify_check_config: ClarifyCheckConfig,
154154
check_job_config: CheckJobConfig,
155-
skip_check: Union[bool, PipelineNonPrimitiveInputTypes] = False,
156-
register_new_baseline: Union[bool, PipelineNonPrimitiveInputTypes] = False,
157-
model_package_group_name: Union[str, PipelineNonPrimitiveInputTypes] = None,
158-
supplied_baseline_constraints: Union[str, PipelineNonPrimitiveInputTypes] = None,
155+
skip_check: Union[bool, PipelineVariable] = False,
156+
register_new_baseline: Union[bool, PipelineVariable] = False,
157+
model_package_group_name: Union[str, PipelineVariable] = None,
158+
supplied_baseline_constraints: Union[str, PipelineVariable] = None,
159159
display_name: str = None,
160160
description: str = None,
161161
cache_config: CacheConfig = None,
@@ -167,14 +167,14 @@ def __init__(
167167
name (str): The name of the ClarifyCheckStep step.
168168
clarify_check_config (ClarifyCheckConfig): A ClarifyCheckConfig instance.
169169
check_job_config (CheckJobConfig): A CheckJobConfig instance.
170-
skip_check (bool or PipelineNonPrimitiveInputTypes): Whether the check
170+
skip_check (bool or PipelineVariable): Whether the check
171171
should be skipped (default: False).
172-
register_new_baseline (bool or PipelineNonPrimitiveInputTypes): Whether
172+
register_new_baseline (bool or PipelineVariable): Whether
173173
the new baseline should be registered (default: False).
174-
model_package_group_name (str or PipelineNonPrimitiveInputTypes): The name of a
174+
model_package_group_name (str or PipelineVariable): The name of a
175175
registered model package group, among which the baseline will be fetched
176176
from the latest approved model (default: None).
177-
supplied_baseline_constraints (str or PipelineNonPrimitiveInputTypes): The S3 path
177+
supplied_baseline_constraints (str or PipelineVariable): The S3 path
178178
to the supplied constraints object representing the constraints JSON file
179179
which will be used for drift to check (default: None).
180180
display_name (str): The display name of the ClarifyCheckStep step (default: None).

src/sagemaker/workflow/entities.py

+1 -37
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
import abc
1717

1818
from enum import EnumMeta
19-
from typing import Any, Dict, List, Union, Optional
19+
from typing import Any, Dict, List, Union
2020

2121
PrimitiveType = Union[str, int, bool, float, None]
2222
RequestType = Union[Dict[str, Any], List[Dict[str, Any]]]
@@ -102,39 +102,3 @@ def to_string(self):
102102
@abc.abstractmethod
103103
def expr(self) -> RequestType:
104104
"""Get the expression structure for workflow service calls."""
105-
106-
def startswith(
107-
self,
108-
prefix: Union[str, tuple], # pylint: disable=unused-argument
109-
start: Optional[int] = None, # pylint: disable=unused-argument
110-
end: Optional[int] = None, # pylint: disable=unused-argument
111-
) -> bool:
112-
"""Simulate the Python string's built-in method: startswith
113-
114-
Args:
115-
prefix (str, tuple): The (tuple of) string to be checked.
116-
start (int): To set the start index of the matching boundary (default: None).
117-
end (int): To set the end index of the matching boundary (default: None).
118-
119-
Return:
120-
bool: Always return False as Pipeline variables are parsed during execution runtime
121-
"""
122-
return False
123-
124-
def endswith(
125-
self,
126-
suffix: Union[str, tuple], # pylint: disable=unused-argument
127-
start: Optional[int] = None, # pylint: disable=unused-argument
128-
end: Optional[int] = None, # pylint: disable=unused-argument
129-
) -> bool:
130-
"""Simulate the Python string's built-in method: endswith
131-
132-
Args:
133-
suffix (str, tuple): The (tuple of) string to be checked.
134-
start (int): To set the start index of the matching boundary (default: None).
135-
end (int): To set the end index of the matching boundary (default: None).
136-
137-
Return:
138-
bool: Always return False as Pipeline variables are parsed during execution runtime
139-
"""
140-
return False

src/sagemaker/workflow/fail_step.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -15,9 +15,9 @@
1515

1616
from typing import List, Union, Optional
1717

18-
from sagemaker.workflow import PipelineNonPrimitiveInputTypes
1918
from sagemaker.workflow.entities import (
2019
RequestType,
20+
PipelineVariable,
2121
)
2222
from sagemaker.workflow.step_collections import StepCollection
2323
from sagemaker.workflow.steps import Step, StepTypeEnum
@@ -29,7 +29,7 @@ class FailStep(Step):
2929
def __init__(
3030
self,
3131
name: str,
32-
error_message: Union[str, PipelineNonPrimitiveInputTypes] = None,
32+
error_message: Union[str, PipelineVariable] = None,
3333
display_name: str = None,
3434
description: str = None,
3535
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
@@ -39,7 +39,7 @@ def __init__(
3939
Args:
4040
name (str): The name of the `FailStep`. A name is required and must be
4141
unique within a pipeline.
42-
error_message (str or PipelineNonPrimitiveInputTypes):
42+
error_message (str or PipelineVariable):
4343
An error message defined by the user.
4444
Once the `FailStep` is reached, the execution fails and the
4545
error message is set as the failure reason (default: None).

src/sagemaker/workflow/parameters.py

+3 -41
Original file line number | Diff line number | Diff line change
@@ -99,29 +99,6 @@ def _expr(cls, name):
9999
"""
100100
return {"Get": f"Parameters.{name}"}
101101

102-
@classmethod
103-
def _implicit_value(cls, value, python_type, args, kwargs):
104-
"""Determine the implicit value from the arguments.
105-
106-
The implicit value of the instance should be the default_value if present.
107-
108-
Args:
109-
value: The default implicit value.
110-
python_type: The Python type the implicit value should be.
111-
args: The list of positional arguments.
112-
kwargs: The dict of keyword arguments.
113-
114-
Returns:
115-
The implicit value that should be used.
116-
"""
117-
if len(args) == 2:
118-
value = args[1] or value
119-
elif kwargs:
120-
value = kwargs.get("default_value", value)
121-
cls._check_default_value_type(value, python_type)
122-
123-
return value
124-
125102
@classmethod
126103
def _check_default_value_type(cls, value, python_type):
127104
"""Check whether the default value is compatible with the parameter type.
@@ -143,14 +120,9 @@ def _check_default_value_type(cls, value, python_type):
143120
ParameterBoolean = partial(Parameter, parameter_type=ParameterTypeEnum.BOOLEAN)
144121

145122

146-
class ParameterString(Parameter, str):
123+
class ParameterString(Parameter):
147124
"""String parameter for pipelines."""
148125

149-
def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument
150-
"""Subclass str"""
151-
val = cls._implicit_value("", str, args, kwargs)
152-
return str.__new__(cls, val)
153-
154126
def __init__(self, name: str, default_value: str = None, enum_values: List[str] = None):
155127
"""Create a pipeline string parameter.
156128
@@ -186,14 +158,9 @@ def to_request(self) -> RequestType:
186158
return request_dict
187159

188160

189-
class ParameterInteger(Parameter, int):
161+
class ParameterInteger(Parameter):
190162
"""Integer parameter for pipelines."""
191163

192-
def __new__(cls, *args, **kwargs):
193-
"""Subclass int"""
194-
val = cls._implicit_value(0, int, args, kwargs)
195-
return int.__new__(cls, val)
196-
197164
def __init__(self, name: str, default_value: int = None):
198165
"""Create a pipeline integer parameter.
199166
@@ -209,14 +176,9 @@ def __init__(self, name: str, default_value: int = None):
209176
)
210177

211178

212-
class ParameterFloat(Parameter, float):
179+
class ParameterFloat(Parameter):
213180
"""Float parameter for pipelines."""
214181

215-
def __new__(cls, *args, **kwargs):
216-
"""Subclass float"""
217-
val = cls._implicit_value(0.0, float, args, kwargs)
218-
return float.__new__(cls, val)
219-
220182
def __init__(self, name: str, default_value: float = None):
221183
"""Create a pipeline float parameter.
222184

0 commit comments

Comments
 (0)