
fix: Add back the fix for Pipeline-variable-related customer issues #3043


Merged: 2 commits, Apr 7, 2022
10 changes: 6 additions & 4 deletions src/sagemaker/estimator.py
@@ -74,7 +74,7 @@
     get_config_value,
     name_from_base,
 )
-from sagemaker.workflow.entities import PipelineVariable
+from sagemaker.workflow import is_pipeline_variable

 logger = logging.getLogger(__name__)

@@ -600,7 +600,7 @@ def _json_encode_hyperparameters(hyperparameters: Dict[str, Any]) -> Dict[str, A
     current_hyperparameters = hyperparameters
     if current_hyperparameters is not None:
         hyperparameters = {
-            str(k): (v.to_string() if isinstance(v, PipelineVariable) else json.dumps(v))
+            str(k): (v.to_string() if is_pipeline_variable(v) else json.dumps(v))
            for (k, v) in current_hyperparameters.items()
        }
    return hyperparameters
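
Note: the reason this hunk branches on `is_pipeline_variable` is that pipeline placeholders are not JSON-serializable. A minimal sketch of the two paths, assuming the SDK's `ParameterString` and the `to_string()` method used throughout this diff (the parameter name is made up):

```python
import json

from sagemaker.workflow.parameters import ParameterString

batch_size = ParameterString(name="BatchSize", default_value="64")

# json.dumps(batch_size) would raise TypeError: the placeholder's value is only
# resolved at pipeline execution time. to_string() instead wraps it in a
# pipeline expression that the service substitutes at runtime.
encoded = {
    "batch_size": batch_size.to_string(),  # placeholder-safe path
    "epochs": json.dumps(10),              # plain values keep the json.dumps path
}
```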
@@ -1811,7 +1811,7 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
     current_hyperparameters = estimator.hyperparameters()
     if current_hyperparameters is not None:
         hyperparameters = {
-            str(k): (v.to_string() if isinstance(v, PipelineVariable) else str(v))
+            str(k): (v.to_string() if is_pipeline_variable(v) else str(v))
             for (k, v) in current_hyperparameters.items()
         }

@@ -1879,7 +1879,9 @@ def _add_spot_checkpoint_args(cls, local_mode, estimator, train_args):
     if estimator.use_spot_instances:
         if local_mode:
             raise ValueError("Spot training is not supported in local mode.")
-        train_args["use_spot_instances"] = True
+        # estimator.use_spot_instances may be a Pipeline ParameterBoolean object
+        # which is parsed during the Pipeline execution runtime
+        train_args["use_spot_instances"] = estimator.use_spot_instances

     if estimator.checkpoint_s3_uri:
         if local_mode:
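Note: a minimal sketch of the customer scenario this hunk fixes, assuming the SDK's `ParameterBoolean` pipeline parameter; the image URI and role are placeholders, not a verified configuration:

```python
from sagemaker.estimator import Estimator
from sagemaker.workflow.parameters import ParameterBoolean

use_spot = ParameterBoolean(name="UseSpot", default_value=True)

estimator = Estimator(
    image_uri="<training-image-uri>",  # placeholder
    role="<execution-role-arn>",       # placeholder
    instance_count=1,
    instance_type="ml.m5.xlarge",
    use_spot_instances=use_spot,
)
# With the fix, train_args["use_spot_instances"] carries the ParameterBoolean
# through to the pipeline definition instead of a hard-coded True, so the value
# can be overridden per pipeline execution.
```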
3 changes: 2 additions & 1 deletion src/sagemaker/model.py
@@ -37,6 +37,7 @@
 from sagemaker.utils import unique_name_from_base
 from sagemaker.async_inference import AsyncInferenceConfig
 from sagemaker.predictor_async import AsyncPredictor
+from sagemaker.workflow import is_pipeline_variable

 LOGGER = logging.getLogger("sagemaker")

@@ -449,7 +450,7 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
         )

         if repack and self.model_data is not None and self.entry_point is not None:
-            if isinstance(self.model_data, sagemaker.workflow.properties.Properties):
+            if is_pipeline_variable(self.model_data):
                 # model is not yet there, defer repacking to later during pipeline execution
                 return
10 changes: 4 additions & 6 deletions src/sagemaker/parameter.py
@@ -15,7 +15,7 @@

 import json

-from sagemaker.workflow.entities import PipelineVariable
+from sagemaker.workflow import is_pipeline_variable


 class ParameterRange(object):

@@ -72,10 +72,10 @@ def as_tuning_range(self, name):
         return {
             "Name": name,
             "MinValue": str(self.min_value)
-            if not isinstance(self.min_value, PipelineVariable)
+            if not is_pipeline_variable(self.min_value)
             else self.min_value.to_string(),
             "MaxValue": str(self.max_value)
-            if not isinstance(self.max_value, PipelineVariable)
+            if not is_pipeline_variable(self.max_value)
             else self.max_value.to_string(),
             "ScalingType": self.scaling_type,
         }

@@ -110,9 +110,7 @@ def __init__(self, values):  # pylint: disable=super-init-not-called
         This input will be converted into a list of strings.
         """
         values = values if isinstance(values, list) else [values]
-        self.values = [
-            str(v) if not isinstance(v, PipelineVariable) else v.to_string() for v in values
-        ]
+        self.values = [str(v) if not is_pipeline_variable(v) else v.to_string() for v in values]

     def as_tuning_range(self, name):
         """Represent the parameter range as a dictionary.
8 changes: 7 additions & 1 deletion src/sagemaker/processing.py
@@ -34,14 +34,14 @@
 from sagemaker.local import LocalSession
 from sagemaker.utils import base_name_from_image, get_config_value, name_from_base
 from sagemaker.session import Session
+from sagemaker.workflow import is_pipeline_variable
 from sagemaker.workflow.properties import Properties
 from sagemaker.workflow.parameters import Parameter
 from sagemaker.workflow.entities import Expression
 from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition
 from sagemaker.apiutils._base_types import ApiObject
 from sagemaker.s3 import S3Uploader

-
 logger = logging.getLogger(__name__)

@@ -233,6 +233,12 @@ def _normalize_args(
             kms_key (str): The ARN of the KMS key that is used to encrypt the
                 user code file (default: None).
         """
+        if code and is_pipeline_variable(code):
+            raise ValueError(
+                "code argument has to be a valid S3 URI or local file path "
+                + "rather than a pipeline variable"
+            )
+
        self._current_job_name = self._generate_current_job_name(job_name=job_name)

        inputs_with_code = self._include_code_in_inputs(inputs, code, kms_key)
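Note: this guard turns a confusing later failure into an immediate definition-time error. A hedged sketch of what it rejects (the processor arguments are placeholders, not a verified configuration):

```python
from sagemaker.processing import ScriptProcessor
from sagemaker.workflow.parameters import ParameterString

processor = ScriptProcessor(
    image_uri="<processing-image-uri>",  # placeholder
    command=["python3"],
    role="<execution-role-arn>",         # placeholder
    instance_count=1,
    instance_type="ml.m5.xlarge",
)

# `code` must resolve to a file the SDK can upload when the pipeline is
# defined, so a pipeline placeholder is rejected up front:
processor.run(code=ParameterString(name="Script"))  # raises ValueError
```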
12 changes: 9 additions & 3 deletions src/sagemaker/session.py
@@ -763,6 +763,8 @@ def _get_train_request(  # noqa: C901
         train_request["EnableInterContainerTrafficEncryption"] = encrypt_inter_container_traffic

         if use_spot_instances:
+            # estimator.use_spot_instances may be a Pipeline ParameterBoolean object
+            # which is parsed during the Pipeline execution runtime
             train_request["EnableManagedSpotTraining"] = use_spot_instances

         if checkpoint_s3_uri:

@@ -2340,13 +2342,17 @@ def _map_training_config(
         training_job_definition["VpcConfig"] = vpc_config

         if enable_network_isolation:
-            training_job_definition["EnableNetworkIsolation"] = True
+            training_job_definition["EnableNetworkIsolation"] = enable_network_isolation
Reviewer (Contributor) commented:

nit: Isn't this redundant, if False will never come here?

@qidewenwhen (Member, Author) replied on Apr 7, 2022:
This was caught by a customer-reported issue.
enable_network_isolation can be a pipeline variable such as ParameterBoolean, which is not parsed at compile time (the SDK stage).
In that case, however, execution still reaches this line and hard-codes True, which means the ParameterBoolean object is lost from the pipeline definition and users cannot update the parameter at runtime (see the sketch below).
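
A minimal sketch of the failure mode described above, assuming the SDK's `ParameterBoolean` API and that the placeholder object is truthy at compile time (as the author's comment implies); the surrounding dict is illustrative:

```python
from sagemaker.workflow.parameters import ParameterBoolean

enable_network_isolation = ParameterBoolean(
    name="EnableNetworkIsolation", default_value=False
)
training_job_definition = {}

# The placeholder is truthy when the SDK builds the definition, so this branch
# is taken even though the real value is only known at execution time.
if enable_network_isolation:
    # Old behavior: a hard-coded True silently replaced the placeholder.
    # New behavior: keep the placeholder so it serializes as
    # {'Get': 'Parameters.EnableNetworkIsolation'} and stays overridable per run.
    training_job_definition["EnableNetworkIsolation"] = enable_network_isolation
```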


         if encrypt_inter_container_traffic:
-            training_job_definition["EnableInterContainerTrafficEncryption"] = True
+            training_job_definition[
+                "EnableInterContainerTrafficEncryption"
+            ] = encrypt_inter_container_traffic

         if use_spot_instances:
-            training_job_definition["EnableManagedSpotTraining"] = True
+            # use_spot_instances may be a Pipeline ParameterBoolean object
+            # which is parsed during the Pipeline execution runtime
+            training_job_definition["EnableManagedSpotTraining"] = use_spot_instances

         if checkpoint_s3_uri:
             checkpoint_config = {"S3Uri": checkpoint_s3_uri}
5 changes: 4 additions & 1 deletion src/sagemaker/tensorflow/model.py
@@ -21,6 +21,7 @@
 from sagemaker.deprecations import removed_kwargs
 from sagemaker.predictor import Predictor
 from sagemaker.serializers import JSONSerializer
+from sagemaker.workflow import is_pipeline_variable


 class TensorFlowPredictor(Predictor):

@@ -335,7 +336,9 @@ def prepare_container_def(
         )
         env = self._get_container_env()

-        if self.entry_point:
+        # If self.model_data is pipeline variable, model is not yet there.
+        # So defer repacking to later during pipeline execution
+        if self.entry_point and not is_pipeline_variable(self.model_data):
             key_prefix = sagemaker.fw_utils.model_code_key_prefix(
                 self.key_prefix, self.name, image_uri
             )
19 changes: 2 additions & 17 deletions src/sagemaker/tuner.py
@@ -38,13 +38,10 @@
     IntegerParameter,
     ParameterRange,
 )
-from sagemaker.workflow.entities import PipelineVariable
-from sagemaker.workflow.parameters import Parameter as PipelineParameter
-from sagemaker.workflow.functions import JsonGet as PipelineJsonGet
-from sagemaker.workflow.functions import Join as PipelineJoin

 from sagemaker.session import Session
 from sagemaker.utils import base_from_name, base_name_from_image, name_from_base
+from sagemaker.workflow import is_pipeline_variable

 AMAZON_ESTIMATOR_MODULE = "sagemaker"
 AMAZON_ESTIMATOR_CLS_NAMES = {

@@ -64,18 +61,6 @@
 logger = logging.getLogger(__name__)


-def is_pipeline_parameters(value):
-    """Determine if a value is a pipeline parameter or function representation
-
-    Args:
-        value (float or int): The value to be verified.
-
-    Returns:
-        bool: True if it is, False otherwise.
-    """
-    return isinstance(value, (PipelineParameter, PipelineJsonGet, PipelineJoin))
-
-
 class WarmStartTypes(Enum):
     """Warm Start Configuration type.

@@ -377,7 +362,7 @@ def _prepare_static_hyperparameters(
         """Prepare static hyperparameters for one estimator before tuning."""
         # Remove any hyperparameter that will be tuned
         static_hyperparameters = {
-            str(k): str(v) if not isinstance(v, PipelineVariable) else v.to_string()
+            str(k): str(v) if not is_pipeline_variable(v) else v.to_string()
             for (k, v) in estimator.hyperparameters().items()
         }
         for hyperparameter_name in hyperparameter_ranges.keys():
15 changes: 15 additions & 0 deletions src/sagemaker/workflow/__init__.py
@@ -20,3 +20,18 @@
 from sagemaker.workflow.properties import Properties

 PipelineNonPrimitiveInputTypes = Union[ExecutionVariable, Expression, Parameter, Properties]
+
+
+def is_pipeline_variable(var: object) -> bool:
+    """Check if the variable is a pipeline variable
+
+    Args:
+        var (object): The variable to be verified.
+    Returns:
+        bool: True if it is, False otherwise.
+    """
+
+    # Currently Expression is on top of all kinds of pipeline variables
+    # as well as PipelineExperimentConfigProperty and PropertyFile
+    # TODO: We should deprecate the Expression and replace it with PipelineVariable
+    return isinstance(var, Expression)
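
Note: a short usage sketch of the new helper; per the docstring above, any `Expression` subclass (Parameter, ExecutionVariable, Properties, Join, JsonGet, ...) counts as a pipeline variable:

```python
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.parameters import ParameterString

print(is_pipeline_variable(ParameterString(name="InputData")))  # True
print(is_pipeline_variable("s3://bucket/prefix/data.csv"))      # False

# Call sites can now replace the scattered isinstance tuples, e.g.
#   isinstance(v, (ExecutionVariable, Expression, Parameter, Properties))
# with a single is_pipeline_variable(v) check.
```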
4 changes: 2 additions & 2 deletions src/sagemaker/workflow/_repack_model.py
@@ -39,14 +39,14 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None):

     Args:
         inference_script (str): The path to the custom entry point.
-        model_archive (str): The name of the model TAR archive.
+        model_archive (str): The name or path (e.g. s3 uri) of the model TAR archive.
         dependencies (str): A space-delimited string of paths to custom dependencies.
         source_dir (str): The path to a custom source directory.
     """

     # the data directory contains a model archive generated by a previous training job
     data_directory = "/opt/ml/input/data/training"
-    model_path = os.path.join(data_directory, model_archive)
+    model_path = os.path.join(data_directory, model_archive.split("/")[-1])
Reviewer (Contributor) commented:

Why are we changing this?

@qidewenwhen (Member, Author) replied:
This _repack_model.py is a script that is uploaded and executed during pipeline execution (on the backend), by which point all pipeline variables have been parsed into their proper types, such as str.
So it is safe to do the string split in this script.
A corresponding change is in _utils.py below: since the string split moves into _repack_model.py, we can clean up those condition-check lines in _utils.py. A sketch of the resulting behavior follows.
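
A small illustrative sketch of the new basename handling, using a made-up S3 URI:

```python
import os

data_directory = "/opt/ml/input/data/training"

# By execution time, model_archive is a plain string -- either a bare archive
# name or a full S3 URI -- so splitting on "/" and taking the last segment
# yields the same local path in both cases.
for model_archive in ("model.tar.gz", "s3://my-bucket/prefix/model.tar.gz"):
    model_path = os.path.join(data_directory, model_archive.split("/")[-1])
    print(model_path)  # /opt/ml/input/data/training/model.tar.gz
```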

     # create a temporary directory
     with tempfile.TemporaryDirectory() as tmp:
10 changes: 2 additions & 8 deletions src/sagemaker/workflow/_utils.py
@@ -134,12 +134,6 @@ def __init__(
         self._model_data = model_data
         self.sagemaker_session = sagemaker_session
         self.role = role
-        if isinstance(model_data, Properties):
-            self._model_prefix = model_data
-            self._model_archive = "model.tar.gz"
-        else:
-            self._model_prefix = "/".join(self._model_data.split("/")[:-1])
-            self._model_archive = self._model_data.split("/")[-1]
         self._entry_point = entry_point
         self._entry_point_basename = os.path.basename(self._entry_point)
         self._source_dir = source_dir

@@ -161,7 +155,7 @@ def __init__(
             role=self.role,
             hyperparameters={
                 "inference_script": self._entry_point_basename,
-                "model_archive": self._model_archive,
+                "model_archive": self._model_data,
                 "dependencies": dependencies_hyperparameter,
                 "source_dir": self._source_dir,
             },

@@ -170,7 +164,7 @@
             **kwargs,
         )
         repacker.disable_profiler = True
-        inputs = TrainingInput(self._model_prefix)
+        inputs = TrainingInput(self._model_data)

         # super!
         super(_RepackModelStep, self).__init__(
4 changes: 3 additions & 1 deletion src/sagemaker/workflow/airflow.py
@@ -184,7 +184,9 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=
         train_config["VpcConfig"] = job_config["vpc_config"]

     if estimator.use_spot_instances:
-        train_config["EnableManagedSpotTraining"] = True
+        # estimator.use_spot_instances may be a Pipeline ParameterBoolean object
+        # which is parsed during the Pipeline execution runtime
+        train_config["EnableManagedSpotTraining"] = estimator.use_spot_instances

     if estimator.hyperparameters() is not None:
         hyperparameters = {str(k): str(v) for (k, v) in estimator.hyperparameters().items()}
15 changes: 6 additions & 9 deletions src/sagemaker/workflow/clarify_check_step.py
@@ -37,8 +37,8 @@
 from sagemaker.model_monitor.model_monitoring import _MODEL_MONITOR_S3_PATH
 from sagemaker.processing import ProcessingInput, ProcessingOutput, ProcessingJob
 from sagemaker.utils import name_from_base
-from sagemaker.workflow import PipelineNonPrimitiveInputTypes, ExecutionVariable, Parameter
-from sagemaker.workflow.entities import RequestType, Expression
+from sagemaker.workflow import PipelineNonPrimitiveInputTypes, is_pipeline_variable
+from sagemaker.workflow.entities import RequestType
 from sagemaker.workflow.properties import Properties
 from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig
 from sagemaker.workflow.check_job_config import CheckJobConfig

@@ -193,18 +193,15 @@ def __init__(
                 + "DataBiasCheckConfig, ModelBiasCheckConfig or ModelExplainabilityCheckConfig"
             )

-        if isinstance(
-            clarify_check_config.data_config.s3_analysis_config_output_path,
-            (ExecutionVariable, Expression, Parameter, Properties),
-        ):
+        if is_pipeline_variable(clarify_check_config.data_config.s3_analysis_config_output_path):
             raise RuntimeError(
                 "s3_analysis_config_output_path cannot be of type "
                 + "ExecutionVariable/Expression/Parameter/Properties"
             )

-        if not clarify_check_config.data_config.s3_analysis_config_output_path and isinstance(
-            clarify_check_config.data_config.s3_output_path,
-            (ExecutionVariable, Expression, Parameter, Properties),
+        if (
+            not clarify_check_config.data_config.s3_analysis_config_output_path
+            and is_pipeline_variable(clarify_check_config.data_config.s3_output_path)
         ):
             raise RuntimeError(
                 "`s3_output_path` cannot be of type ExecutionVariable/Expression/Parameter"
4 changes: 2 additions & 2 deletions src/sagemaker/workflow/condition_step.py
@@ -26,8 +26,8 @@
 from sagemaker.workflow.step_collections import StepCollection
 from sagemaker.workflow.utilities import list_to_request
 from sagemaker.workflow.entities import (
-    Expression,
     RequestType,
+    PipelineVariable,
 )
 from sagemaker.workflow.properties import (
     Properties,

@@ -95,7 +95,7 @@ def properties(self):


 @attr.s
-class JsonGet(Expression):  # pragma: no cover
+class JsonGet(PipelineVariable):  # pragma: no cover
     """Get JSON properties from PropertyFiles.

     Attributes:
3 changes: 2 additions & 1 deletion src/sagemaker/workflow/conditions.py
@@ -22,6 +22,7 @@

 import attr

+from sagemaker.workflow import is_pipeline_variable
 from sagemaker.workflow.entities import (
     DefaultEnumMeta,
     Entity,

@@ -261,6 +262,6 @@ def primitive_or_expr(
     Returns:
         Either the expression of the value or the primitive value.
     """
-    if isinstance(value, (ExecutionVariable, Expression, Parameter, Properties)):
+    if is_pipeline_variable(value):
         return value.expr
     return value
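
Note: illustrative behavior of `primitive_or_expr` after the change, assuming the module-level helper is importable from `sagemaker.workflow.conditions` and using a made-up parameter name:

```python
from sagemaker.workflow.conditions import primitive_or_expr
from sagemaker.workflow.parameters import ParameterInteger

threshold = ParameterInteger(name="Threshold", default_value=10)

# Pipeline variables are serialized via their request-dict expression...
print(primitive_or_expr(threshold))  # {'Get': 'Parameters.Threshold'}
# ...while primitives pass through unchanged.
print(primitive_or_expr(0.5))        # 0.5
```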
16 changes: 10 additions & 6 deletions src/sagemaker/workflow/pipeline.py
@@ -339,16 +339,20 @@ def interpolate(
     Args:
         request_obj (RequestType): The request dict.
         callback_output_to_step_map (Dict[str, str]): A dict of output name -> step name.
+        lambda_output_to_step_map (Dict[str, str]): A dict of output name -> step name.

     Returns:
         RequestType: The request dict with Parameter values replaced by their expression.
     """
-    request_obj_copy = deepcopy(request_obj)
-    return _interpolate(
-        request_obj_copy,
-        callback_output_to_step_map=callback_output_to_step_map,
-        lambda_output_to_step_map=lambda_output_to_step_map,
-    )
+    try:
+        request_obj_copy = deepcopy(request_obj)
+        return _interpolate(
+            request_obj_copy,
+            callback_output_to_step_map=callback_output_to_step_map,
+            lambda_output_to_step_map=lambda_output_to_step_map,
+        )
+    except TypeError as type_err:
+        raise TypeError("Not able to interpolate Pipeline definition: %s" % type_err)


 def _interpolate(