Skip to content

Commit 2cca3c6

Browse files
author
Dewen Qi
committed
change: turn off PipelineVariable inheritance from python primitives
1 parent c46a09d commit 2cca3c6

15 files changed

+81
-195
lines changed

src/sagemaker/estimator.py

+1
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,7 @@ def __init__(
483483
if (
484484
not self.sagemaker_session.local_mode
485485
and output_path
486+
and not is_pipeline_variable(output_path)
486487
and output_path.startswith("file://")
487488
):
488489
raise RuntimeError("file:// output paths are only supported in Local Mode")

src/sagemaker/fw_utils.py

+9
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import sagemaker.utils
2727

2828
from sagemaker.deprecations import renamed_warning
29+
from sagemaker.workflow import is_pipeline_variable
2930

3031
logger = logging.getLogger(__name__)
3132

@@ -421,6 +422,10 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
421422
"""
422423
if training_instance_type == "local" or distribution is None:
423424
return
425+
if is_pipeline_variable(training_instance_type):
426+
# The training_instance_type is not available at compile time.
427+
# Rather, it is supplied at pipeline execution time.
428+
return
424429

425430
is_multi_gpu_instance = (
426431
training_instance_type == "local_gpu"
@@ -478,6 +483,10 @@ def validate_smdistributed(
478483
if "smdistributed" not in distribution:
479484
# Distribution strategy other than smdistributed is selected
480485
return
486+
if is_pipeline_variable(instance_type):
487+
# The instance_type is not available at compile time.
488+
# Rather, it is supplied at pipeline execution time.
489+
return
481490

482491
# distribution contains smdistributed
483492
smdistributed = distribution["smdistributed"]

src/sagemaker/workflow/__init__.py

-6
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,8 @@
1212
# language governing permissions and limitations under the License.
1313
"""Defines Types etc. used in workflow."""
1414
from __future__ import absolute_import
15-
from typing import Union
1615

1716
from sagemaker.workflow.entities import Expression
18-
from sagemaker.workflow.execution_variables import ExecutionVariable
19-
from sagemaker.workflow.parameters import Parameter
20-
from sagemaker.workflow.properties import Properties
21-
22-
PipelineNonPrimitiveInputTypes = Union[ExecutionVariable, Expression, Parameter, Properties]
2317

2418

2519
def is_pipeline_variable(var: object) -> bool:

src/sagemaker/workflow/clarify_check_step.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@
3737
from sagemaker.model_monitor.model_monitoring import _MODEL_MONITOR_S3_PATH
3838
from sagemaker.processing import ProcessingInput, ProcessingOutput, ProcessingJob
3939
from sagemaker.utils import name_from_base
40-
from sagemaker.workflow import PipelineNonPrimitiveInputTypes, is_pipeline_variable
41-
from sagemaker.workflow.entities import RequestType
40+
from sagemaker.workflow import is_pipeline_variable
41+
from sagemaker.workflow.entities import RequestType, PipelineVariable
4242
from sagemaker.workflow.properties import Properties
4343
from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig
4444
from sagemaker.workflow.check_job_config import CheckJobConfig
@@ -58,7 +58,7 @@ class ClarifyCheckConfig(ABC):
5858
data_config (DataConfig): Config of the input/output data.
5959
kms_key (str): The ARN of the KMS key that is used to encrypt the
6060
user code file (default: None).
61-
This field CANNOT be any of PipelineNonPrimitiveInputTypes.
61+
This field CANNOT be any type of the `PipelineVariable`.
6262
monitoring_analysis_config_uri: (str): The uri of monitoring analysis config.
6363
This field does not take input.
6464
It will be generated once uploading the created analysis config file.
@@ -85,7 +85,7 @@ class DataBiasCheckConfig(ClarifyCheckConfig):
8585
"`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kolmogorov-smirnov.html>`_",
8686
"`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-cddl.html>`_"].
8787
Defaults to computing all.
88-
This field CANNOT be any of PipelineNonPrimitiveInputTypes.
88+
This field CANNOT be any type of the `PipelineVariable`.
8989
""" # noqa E501
9090

9191
data_bias_config: BiasConfig = attr.ib()
@@ -114,7 +114,7 @@ class ModelBiasCheckConfig(ClarifyCheckConfig):
114114
", "`TE <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-te.html>`_",
115115
"`FT <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ft.html>`_"].
116116
Defaults to computing all.
117-
This field CANNOT be any of PipelineNonPrimitiveInputTypes.
117+
This field CANNOT be any type of the `PipelineVariable`.
118118
"""
119119

120120
data_bias_config: BiasConfig = attr.ib()
@@ -135,7 +135,7 @@ class ModelExplainabilityCheckConfig(ClarifyCheckConfig):
135135
in the model output for the predicted scores to be explained (default: None).
136136
This is not required if the model output is a single score. Alternatively,
137137
an instance of ModelPredictedLabelConfig can be provided
138-
but this field CANNOT be any of PipelineNonPrimitiveInputTypes.
138+
but this field CANNOT be any type of the `PipelineVariable`.
139139
"""
140140

141141
model_config: ModelConfig = attr.ib()
@@ -151,10 +151,10 @@ def __init__(
151151
name: str,
152152
clarify_check_config: ClarifyCheckConfig,
153153
check_job_config: CheckJobConfig,
154-
skip_check: Union[bool, PipelineNonPrimitiveInputTypes] = False,
155-
register_new_baseline: Union[bool, PipelineNonPrimitiveInputTypes] = False,
156-
model_package_group_name: Union[str, PipelineNonPrimitiveInputTypes] = None,
157-
supplied_baseline_constraints: Union[str, PipelineNonPrimitiveInputTypes] = None,
154+
skip_check: Union[bool, PipelineVariable] = False,
155+
register_new_baseline: Union[bool, PipelineVariable] = False,
156+
model_package_group_name: Union[str, PipelineVariable] = None,
157+
supplied_baseline_constraints: Union[str, PipelineVariable] = None,
158158
display_name: str = None,
159159
description: str = None,
160160
cache_config: CacheConfig = None,
@@ -166,14 +166,14 @@ def __init__(
166166
name (str): The name of the ClarifyCheckStep step.
167167
clarify_check_config (ClarifyCheckConfig): A ClarifyCheckConfig instance.
168168
check_job_config (CheckJobConfig): A CheckJobConfig instance.
169-
skip_check (bool or PipelineNonPrimitiveInputTypes): Whether the check
169+
skip_check (bool or PipelineVariable): Whether the check
170170
should be skipped (default: False).
171-
register_new_baseline (bool or PipelineNonPrimitiveInputTypes): Whether
171+
register_new_baseline (bool or PipelineVariable): Whether
172172
the new baseline should be registered (default: False).
173-
model_package_group_name (str or PipelineNonPrimitiveInputTypes): The name of a
173+
model_package_group_name (str or PipelineVariable): The name of a
174174
registered model package group, among which the baseline will be fetched
175175
from the latest approved model (default: None).
176-
supplied_baseline_constraints (str or PipelineNonPrimitiveInputTypes): The S3 path
176+
supplied_baseline_constraints (str or PipelineVariable): The S3 path
177177
to the supplied constraints object representing the constraints JSON file
178178
which will be used for drift to check (default: None).
179179
display_name (str): The display name of the ClarifyCheckStep step (default: None).

src/sagemaker/workflow/entities.py

+1-37
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import abc
1717

1818
from enum import EnumMeta
19-
from typing import Any, Dict, List, Union, Optional
19+
from typing import Any, Dict, List, Union
2020

2121
PrimitiveType = Union[str, int, bool, float, None]
2222
RequestType = Union[Dict[str, Any], List[Dict[str, Any]]]
@@ -98,39 +98,3 @@ def to_string(self):
9898
@abc.abstractmethod
9999
def expr(self) -> RequestType:
100100
"""Get the expression structure for workflow service calls."""
101-
102-
def startswith(
103-
self,
104-
prefix: Union[str, tuple], # pylint: disable=unused-argument
105-
start: Optional[int] = None, # pylint: disable=unused-argument
106-
end: Optional[int] = None, # pylint: disable=unused-argument
107-
) -> bool:
108-
"""Simulate the Python string's built-in method: startswith
109-
110-
Args:
111-
prefix (str, tuple): The (tuple of) string to be checked.
112-
start (int): To set the start index of the matching boundary (default: None).
113-
end (int): To set the end index of the matching boundary (default: None).
114-
115-
Return:
116-
bool: Always return False as Pipeline variables are parsed during execution runtime
117-
"""
118-
return False
119-
120-
def endswith(
121-
self,
122-
suffix: Union[str, tuple], # pylint: disable=unused-argument
123-
start: Optional[int] = None, # pylint: disable=unused-argument
124-
end: Optional[int] = None, # pylint: disable=unused-argument
125-
) -> bool:
126-
"""Simulate the Python string's built-in method: endswith
127-
128-
Args:
129-
suffix (str, tuple): The (tuple of) string to be checked.
130-
start (int): To set the start index of the matching boundary (default: None).
131-
end (int): To set the end index of the matching boundary (default: None).
132-
133-
Return:
134-
bool: Always return False as Pipeline variables are parsed during execution runtime
135-
"""
136-
return False

src/sagemaker/workflow/fail_step.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515

1616
from typing import List, Union
1717

18-
from sagemaker.workflow import PipelineNonPrimitiveInputTypes
1918
from sagemaker.workflow.entities import (
2019
RequestType,
20+
PipelineVariable,
2121
)
2222
from sagemaker.workflow.steps import Step, StepTypeEnum
2323

@@ -28,7 +28,7 @@ class FailStep(Step):
2828
def __init__(
2929
self,
3030
name: str,
31-
error_message: Union[str, PipelineNonPrimitiveInputTypes] = None,
31+
error_message: Union[str, PipelineVariable] = None,
3232
display_name: str = None,
3333
description: str = None,
3434
depends_on: Union[List[str], List[Step]] = None,
@@ -38,7 +38,7 @@ def __init__(
3838
Args:
3939
name (str): The name of the `FailStep`. A name is required and must be
4040
unique within a pipeline.
41-
error_message (str or PipelineNonPrimitiveInputTypes):
41+
error_message (str or PipelineVariable):
4242
An error message defined by the user.
4343
Once the `FailStep` is reached, the execution fails and the
4444
error message is set as the failure reason (default: None).

src/sagemaker/workflow/parameters.py

+3-41
Original file line numberDiff line numberDiff line change
@@ -99,29 +99,6 @@ def _expr(cls, name):
9999
"""
100100
return {"Get": f"Parameters.{name}"}
101101

102-
@classmethod
103-
def _implicit_value(cls, value, python_type, args, kwargs):
104-
"""Determine the implicit value from the arguments.
105-
106-
The implicit value of the instance should be the default_value if present.
107-
108-
Args:
109-
value: The default implicit value.
110-
python_type: The Python type the implicit value should be.
111-
args: The list of positional arguments.
112-
kwargs: The dict of keyword arguments.
113-
114-
Returns:
115-
The implicit value that should be used.
116-
"""
117-
if len(args) == 2:
118-
value = args[1] or value
119-
elif kwargs:
120-
value = kwargs.get("default_value", value)
121-
cls._check_default_value_type(value, python_type)
122-
123-
return value
124-
125102
@classmethod
126103
def _check_default_value_type(cls, value, python_type):
127104
"""Check whether the default value is compatible with the parameter type.
@@ -143,14 +120,9 @@ def _check_default_value_type(cls, value, python_type):
143120
ParameterBoolean = partial(Parameter, parameter_type=ParameterTypeEnum.BOOLEAN)
144121

145122

146-
class ParameterString(Parameter, str):
123+
class ParameterString(Parameter):
147124
"""String parameter for pipelines."""
148125

149-
def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument
150-
"""Subclass str"""
151-
val = cls._implicit_value("", str, args, kwargs)
152-
return str.__new__(cls, val)
153-
154126
def __init__(self, name: str, default_value: str = None, enum_values: List[str] = None):
155127
"""Create a pipeline string parameter.
156128
@@ -186,14 +158,9 @@ def to_request(self) -> RequestType:
186158
return request_dict
187159

188160

189-
class ParameterInteger(Parameter, int):
161+
class ParameterInteger(Parameter):
190162
"""Integer parameter for pipelines."""
191163

192-
def __new__(cls, *args, **kwargs):
193-
"""Subclass int"""
194-
val = cls._implicit_value(0, int, args, kwargs)
195-
return int.__new__(cls, val)
196-
197164
def __init__(self, name: str, default_value: int = None):
198165
"""Create a pipeline integer parameter.
199166
@@ -209,14 +176,9 @@ def __init__(self, name: str, default_value: int = None):
209176
)
210177

211178

212-
class ParameterFloat(Parameter, float):
179+
class ParameterFloat(Parameter):
213180
"""Float parameter for pipelines."""
214181

215-
def __new__(cls, *args, **kwargs):
216-
"""Subclass float"""
217-
val = cls._implicit_value(0.0, float, args, kwargs)
218-
return float.__new__(cls, val)
219-
220182
def __init__(self, name: str, default_value: float = None):
221183
"""Create a pipeline float parameter.
222184

0 commit comments

Comments
 (0)