Skip to content

change: turn off Pipeline Parameter inheritance from python primitives #3086

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,7 @@ def __init__(
if (
not self.sagemaker_session.local_mode
and output_path
and not is_pipeline_variable(output_path)
and output_path.startswith("file://")
):
raise RuntimeError("file:// output paths are only supported in Local Mode")
Expand Down
8 changes: 8 additions & 0 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,10 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
"""
if training_instance_type == "local" or distribution is None:
return
if is_pipeline_variable(training_instance_type):
# The training_instance_type is not available at compile time.
# Rather, it is provided at Pipeline execution time.
return

is_multi_gpu_instance = (
training_instance_type == "local_gpu"
Expand Down Expand Up @@ -485,6 +489,10 @@ def validate_smdistributed(
if "smdistributed" not in distribution:
# Distribution strategy other than smdistributed is selected
return
if is_pipeline_variable(instance_type):
# The instance_type is not available at compile time.
# Rather, it is provided at Pipeline execution time.
return

# distribution contains smdistributed
smdistributed = distribution["smdistributed"]
Expand Down
6 changes: 0 additions & 6 deletions src/sagemaker/workflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,8 @@
# language governing permissions and limitations under the License.
"""Defines Types etc. used in workflow."""
from __future__ import absolute_import
from typing import Union

from sagemaker.workflow.entities import Expression
from sagemaker.workflow.execution_variables import ExecutionVariable
from sagemaker.workflow.parameters import Parameter
from sagemaker.workflow.properties import Properties

PipelineNonPrimitiveInputTypes = Union[ExecutionVariable, Expression, Parameter, Properties]


def is_pipeline_variable(var: object) -> bool:
Expand Down
28 changes: 14 additions & 14 deletions src/sagemaker/workflow/clarify_check_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
from sagemaker.model_monitor.model_monitoring import _MODEL_MONITOR_S3_PATH
from sagemaker.processing import ProcessingInput, ProcessingOutput, ProcessingJob
from sagemaker.utils import name_from_base
from sagemaker.workflow import PipelineNonPrimitiveInputTypes, is_pipeline_variable
from sagemaker.workflow.entities import RequestType
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import RequestType, PipelineVariable
from sagemaker.workflow.properties import Properties
from sagemaker.workflow.step_collections import StepCollection
from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig
Expand All @@ -59,7 +59,7 @@ class ClarifyCheckConfig(ABC):
data_config (DataConfig): Config of the input/output data.
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).
This field CANNOT be any of PipelineNonPrimitiveInputTypes.
This field CANNOT be any type of the `PipelineVariable`.
monitoring_analysis_config_uri (str): The URI of the monitoring analysis config.
This field does not take input.
It is generated when the created analysis config file is uploaded.
Expand All @@ -86,7 +86,7 @@ class DataBiasCheckConfig(ClarifyCheckConfig):
"`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kolmogorov-smirnov.html>`_",
"`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-cddl.html>`_"].
Defaults to computing all.
This field CANNOT be any of PipelineNonPrimitiveInputTypes.
This field CANNOT be any type of the `PipelineVariable`.
""" # noqa E501

data_bias_config: BiasConfig = attr.ib()
Expand Down Expand Up @@ -115,7 +115,7 @@ class ModelBiasCheckConfig(ClarifyCheckConfig):
", "`TE <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-te.html>`_",
"`FT <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ft.html>`_"].
Defaults to computing all.
This field CANNOT be any of PipelineNonPrimitiveInputTypes.
This field CANNOT be any type of the `PipelineVariable`.
"""

data_bias_config: BiasConfig = attr.ib()
Expand All @@ -136,7 +136,7 @@ class ModelExplainabilityCheckConfig(ClarifyCheckConfig):
in the model output for the predicted scores to be explained (default: None).
This is not required if the model output is a single score. Alternatively,
an instance of ModelPredictedLabelConfig can be provided
but this field CANNOT be any of PipelineNonPrimitiveInputTypes.
but this field CANNOT be any type of the `PipelineVariable`.
"""

model_config: ModelConfig = attr.ib()
Expand All @@ -152,10 +152,10 @@ def __init__(
name: str,
clarify_check_config: ClarifyCheckConfig,
check_job_config: CheckJobConfig,
skip_check: Union[bool, PipelineNonPrimitiveInputTypes] = False,
register_new_baseline: Union[bool, PipelineNonPrimitiveInputTypes] = False,
model_package_group_name: Union[str, PipelineNonPrimitiveInputTypes] = None,
supplied_baseline_constraints: Union[str, PipelineNonPrimitiveInputTypes] = None,
skip_check: Union[bool, PipelineVariable] = False,
register_new_baseline: Union[bool, PipelineVariable] = False,
model_package_group_name: Union[str, PipelineVariable] = None,
supplied_baseline_constraints: Union[str, PipelineVariable] = None,
display_name: str = None,
description: str = None,
cache_config: CacheConfig = None,
Expand All @@ -167,14 +167,14 @@ def __init__(
name (str): The name of the ClarifyCheckStep step.
clarify_check_config (ClarifyCheckConfig): A ClarifyCheckConfig instance.
check_job_config (CheckJobConfig): A CheckJobConfig instance.
skip_check (bool or PipelineNonPrimitiveInputTypes): Whether the check
skip_check (bool or PipelineVariable): Whether the check
should be skipped (default: False).
register_new_baseline (bool or PipelineNonPrimitiveInputTypes): Whether
register_new_baseline (bool or PipelineVariable): Whether
the new baseline should be registered (default: False).
model_package_group_name (str or PipelineNonPrimitiveInputTypes): The name of a
model_package_group_name (str or PipelineVariable): The name of a
registered model package group, among which the baseline will be fetched
from the latest approved model (default: None).
supplied_baseline_constraints (str or PipelineNonPrimitiveInputTypes): The S3 path
supplied_baseline_constraints (str or PipelineVariable): The S3 path
to the supplied constraints object representing the constraints JSON file
which will be used for drift to check (default: None).
display_name (str): The display name of the ClarifyCheckStep step (default: None).
Expand Down
38 changes: 1 addition & 37 deletions src/sagemaker/workflow/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import abc

from enum import EnumMeta
from typing import Any, Dict, List, Union, Optional
from typing import Any, Dict, List, Union

PrimitiveType = Union[str, int, bool, float, None]
RequestType = Union[Dict[str, Any], List[Dict[str, Any]]]
Expand Down Expand Up @@ -102,39 +102,3 @@ def to_string(self):
@abc.abstractmethod
def expr(self) -> RequestType:
"""Get the expression structure for workflow service calls."""

def startswith(
self,
prefix: Union[str, tuple], # pylint: disable=unused-argument
start: Optional[int] = None, # pylint: disable=unused-argument
end: Optional[int] = None, # pylint: disable=unused-argument
) -> bool:
"""Simulate the Python string's built-in method: startswith

Args:
prefix (str, tuple): The (tuple of) string to be checked.
start (int): To set the start index of the matching boundary (default: None).
end (int): To set the end index of the matching boundary (default: None).

Return:
bool: Always returns False, because Pipeline variables are only resolved at execution time.
"""
return False

def endswith(
self,
suffix: Union[str, tuple], # pylint: disable=unused-argument
start: Optional[int] = None, # pylint: disable=unused-argument
end: Optional[int] = None, # pylint: disable=unused-argument
) -> bool:
"""Simulate the Python string's built-in method: endswith

Args:
suffix (str, tuple): The (tuple of) string to be checked.
start (int): To set the start index of the matching boundary (default: None).
end (int): To set the end index of the matching boundary (default: None).

Return:
bool: Always returns False, because Pipeline variables are only resolved at execution time.
"""
return False
6 changes: 3 additions & 3 deletions src/sagemaker/workflow/fail_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@

from typing import List, Union, Optional

from sagemaker.workflow import PipelineNonPrimitiveInputTypes
from sagemaker.workflow.entities import (
RequestType,
PipelineVariable,
)
from sagemaker.workflow.step_collections import StepCollection
from sagemaker.workflow.steps import Step, StepTypeEnum
Expand All @@ -29,7 +29,7 @@ class FailStep(Step):
def __init__(
self,
name: str,
error_message: Union[str, PipelineNonPrimitiveInputTypes] = None,
error_message: Union[str, PipelineVariable] = None,
display_name: str = None,
description: str = None,
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
Expand All @@ -39,7 +39,7 @@ def __init__(
Args:
name (str): The name of the `FailStep`. A name is required and must be
unique within a pipeline.
error_message (str or PipelineNonPrimitiveInputTypes):
error_message (str or PipelineVariable):
An error message defined by the user.
Once the `FailStep` is reached, the execution fails and the
error message is set as the failure reason (default: None).
Expand Down
44 changes: 3 additions & 41 deletions src/sagemaker/workflow/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,29 +99,6 @@ def _expr(cls, name):
"""
return {"Get": f"Parameters.{name}"}

@classmethod
def _implicit_value(cls, value, python_type, args, kwargs):
"""Determine the implicit value from the arguments.

The implicit value of the instance should be the default_value if present.

Args:
value: The default implicit value.
python_type: The Python type the implicit value should be.
args: The list of positional arguments.
kwargs: The dict of keyword arguments.

Returns:
The implicit value that should be used.
"""
if len(args) == 2:
value = args[1] or value
elif kwargs:
value = kwargs.get("default_value", value)
cls._check_default_value_type(value, python_type)

return value

@classmethod
def _check_default_value_type(cls, value, python_type):
"""Check whether the default value is compatible with the parameter type.
Expand All @@ -143,14 +120,9 @@ def _check_default_value_type(cls, value, python_type):
ParameterBoolean = partial(Parameter, parameter_type=ParameterTypeEnum.BOOLEAN)


class ParameterString(Parameter, str):
class ParameterString(Parameter):
"""String parameter for pipelines."""

def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument
"""Subclass str"""
val = cls._implicit_value("", str, args, kwargs)
return str.__new__(cls, val)

def __init__(self, name: str, default_value: str = None, enum_values: List[str] = None):
"""Create a pipeline string parameter.

Expand Down Expand Up @@ -186,14 +158,9 @@ def to_request(self) -> RequestType:
return request_dict


class ParameterInteger(Parameter, int):
class ParameterInteger(Parameter):
"""Integer parameter for pipelines."""

def __new__(cls, *args, **kwargs):
"""Subclass int"""
val = cls._implicit_value(0, int, args, kwargs)
return int.__new__(cls, val)

def __init__(self, name: str, default_value: int = None):
"""Create a pipeline integer parameter.

Expand All @@ -209,14 +176,9 @@ def __init__(self, name: str, default_value: int = None):
)


class ParameterFloat(Parameter, float):
class ParameterFloat(Parameter):
"""Float parameter for pipelines."""

def __new__(cls, *args, **kwargs):
"""Subclass float"""
val = cls._implicit_value(0.0, float, args, kwargs)
return float.__new__(cls, val)

def __init__(self, name: str, default_value: float = None):
"""Create a pipeline float parameter.

Expand Down
Loading