Skip to content

Commit 6ffe925

Browse files
qidewenwhen (Dewen Qi) and Dewen Qi authored
fix: Add back the Fix for Pipeline variables related customer issues (#3043)
* Revert "Revert "fix: Fix Pipeline variables related customer issues (#2959)" (#3041)" — this reverts commit 2782f8c.
* fix: Include deprecated JsonGet into PipelineVariable

Co-authored-by: Dewen Qi <[email protected]>
1 parent 4fc7f2c commit 6ffe925

25 files changed

+472
-93
lines changed

src/sagemaker/estimator.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
get_config_value,
7575
name_from_base,
7676
)
77-
from sagemaker.workflow.entities import PipelineVariable
77+
from sagemaker.workflow import is_pipeline_variable
7878

7979
logger = logging.getLogger(__name__)
8080

@@ -600,7 +600,7 @@ def _json_encode_hyperparameters(hyperparameters: Dict[str, Any]) -> Dict[str, A
600600
current_hyperparameters = hyperparameters
601601
if current_hyperparameters is not None:
602602
hyperparameters = {
603-
str(k): (v.to_string() if isinstance(v, PipelineVariable) else json.dumps(v))
603+
str(k): (v.to_string() if is_pipeline_variable(v) else json.dumps(v))
604604
for (k, v) in current_hyperparameters.items()
605605
}
606606
return hyperparameters
@@ -1811,7 +1811,7 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
18111811
current_hyperparameters = estimator.hyperparameters()
18121812
if current_hyperparameters is not None:
18131813
hyperparameters = {
1814-
str(k): (v.to_string() if isinstance(v, PipelineVariable) else str(v))
1814+
str(k): (v.to_string() if is_pipeline_variable(v) else str(v))
18151815
for (k, v) in current_hyperparameters.items()
18161816
}
18171817

@@ -1879,7 +1879,9 @@ def _add_spot_checkpoint_args(cls, local_mode, estimator, train_args):
18791879
if estimator.use_spot_instances:
18801880
if local_mode:
18811881
raise ValueError("Spot training is not supported in local mode.")
1882-
train_args["use_spot_instances"] = True
1882+
# estimator.use_spot_instances may be a Pipeline ParameterBoolean object
1883+
# which is parsed during the Pipeline execution runtime
1884+
train_args["use_spot_instances"] = estimator.use_spot_instances
18831885

18841886
if estimator.checkpoint_s3_uri:
18851887
if local_mode:

src/sagemaker/model.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from sagemaker.utils import unique_name_from_base
3838
from sagemaker.async_inference import AsyncInferenceConfig
3939
from sagemaker.predictor_async import AsyncPredictor
40+
from sagemaker.workflow import is_pipeline_variable
4041

4142
LOGGER = logging.getLogger("sagemaker")
4243

@@ -449,7 +450,7 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
449450
)
450451

451452
if repack and self.model_data is not None and self.entry_point is not None:
452-
if isinstance(self.model_data, sagemaker.workflow.properties.Properties):
453+
if is_pipeline_variable(self.model_data):
453454
# model is not yet there, defer repacking to later during pipeline execution
454455
return
455456

src/sagemaker/parameter.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import json
1717

18-
from sagemaker.workflow.entities import PipelineVariable
18+
from sagemaker.workflow import is_pipeline_variable
1919

2020

2121
class ParameterRange(object):
@@ -72,10 +72,10 @@ def as_tuning_range(self, name):
7272
return {
7373
"Name": name,
7474
"MinValue": str(self.min_value)
75-
if not isinstance(self.min_value, PipelineVariable)
75+
if not is_pipeline_variable(self.min_value)
7676
else self.min_value.to_string(),
7777
"MaxValue": str(self.max_value)
78-
if not isinstance(self.max_value, PipelineVariable)
78+
if not is_pipeline_variable(self.max_value)
7979
else self.max_value.to_string(),
8080
"ScalingType": self.scaling_type,
8181
}
@@ -110,9 +110,7 @@ def __init__(self, values): # pylint: disable=super-init-not-called
110110
This input will be converted into a list of strings.
111111
"""
112112
values = values if isinstance(values, list) else [values]
113-
self.values = [
114-
str(v) if not isinstance(v, PipelineVariable) else v.to_string() for v in values
115-
]
113+
self.values = [str(v) if not is_pipeline_variable(v) else v.to_string() for v in values]
116114

117115
def as_tuning_range(self, name):
118116
"""Represent the parameter range as a dictionary.

src/sagemaker/processing.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,14 @@
3434
from sagemaker.local import LocalSession
3535
from sagemaker.utils import base_name_from_image, get_config_value, name_from_base
3636
from sagemaker.session import Session
37+
from sagemaker.workflow import is_pipeline_variable
3738
from sagemaker.workflow.properties import Properties
3839
from sagemaker.workflow.parameters import Parameter
3940
from sagemaker.workflow.entities import Expression
4041
from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition
4142
from sagemaker.apiutils._base_types import ApiObject
4243
from sagemaker.s3 import S3Uploader
4344

44-
4545
logger = logging.getLogger(__name__)
4646

4747

@@ -233,6 +233,12 @@ def _normalize_args(
233233
kms_key (str): The ARN of the KMS key that is used to encrypt the
234234
user code file (default: None).
235235
"""
236+
if code and is_pipeline_variable(code):
237+
raise ValueError(
238+
"code argument has to be a valid S3 URI or local file path "
239+
+ "rather than a pipeline variable"
240+
)
241+
236242
self._current_job_name = self._generate_current_job_name(job_name=job_name)
237243

238244
inputs_with_code = self._include_code_in_inputs(inputs, code, kms_key)

src/sagemaker/session.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -763,6 +763,8 @@ def _get_train_request( # noqa: C901
763763
train_request["EnableInterContainerTrafficEncryption"] = encrypt_inter_container_traffic
764764

765765
if use_spot_instances:
766+
# estimator.use_spot_instances may be a Pipeline ParameterBoolean object
767+
# which is parsed during the Pipeline execution runtime
766768
train_request["EnableManagedSpotTraining"] = use_spot_instances
767769

768770
if checkpoint_s3_uri:
@@ -2340,13 +2342,17 @@ def _map_training_config(
23402342
training_job_definition["VpcConfig"] = vpc_config
23412343

23422344
if enable_network_isolation:
2343-
training_job_definition["EnableNetworkIsolation"] = True
2345+
training_job_definition["EnableNetworkIsolation"] = enable_network_isolation
23442346

23452347
if encrypt_inter_container_traffic:
2346-
training_job_definition["EnableInterContainerTrafficEncryption"] = True
2348+
training_job_definition[
2349+
"EnableInterContainerTrafficEncryption"
2350+
] = encrypt_inter_container_traffic
23472351

23482352
if use_spot_instances:
2349-
training_job_definition["EnableManagedSpotTraining"] = True
2353+
# use_spot_instances may be a Pipeline ParameterBoolean object
2354+
# which is parsed during the Pipeline execution runtime
2355+
training_job_definition["EnableManagedSpotTraining"] = use_spot_instances
23502356

23512357
if checkpoint_s3_uri:
23522358
checkpoint_config = {"S3Uri": checkpoint_s3_uri}

src/sagemaker/tensorflow/model.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from sagemaker.deprecations import removed_kwargs
2222
from sagemaker.predictor import Predictor
2323
from sagemaker.serializers import JSONSerializer
24+
from sagemaker.workflow import is_pipeline_variable
2425

2526

2627
class TensorFlowPredictor(Predictor):
@@ -335,7 +336,9 @@ def prepare_container_def(
335336
)
336337
env = self._get_container_env()
337338

338-
if self.entry_point:
339+
# If self.model_data is pipeline variable, model is not yet there.
340+
# So defer repacking to later during pipeline execution
341+
if self.entry_point and not is_pipeline_variable(self.model_data):
339342
key_prefix = sagemaker.fw_utils.model_code_key_prefix(
340343
self.key_prefix, self.name, image_uri
341344
)

src/sagemaker/tuner.py

+2-17
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,10 @@
3838
IntegerParameter,
3939
ParameterRange,
4040
)
41-
from sagemaker.workflow.entities import PipelineVariable
42-
from sagemaker.workflow.parameters import Parameter as PipelineParameter
43-
from sagemaker.workflow.functions import JsonGet as PipelineJsonGet
44-
from sagemaker.workflow.functions import Join as PipelineJoin
4541

4642
from sagemaker.session import Session
4743
from sagemaker.utils import base_from_name, base_name_from_image, name_from_base
44+
from sagemaker.workflow import is_pipeline_variable
4845

4946
AMAZON_ESTIMATOR_MODULE = "sagemaker"
5047
AMAZON_ESTIMATOR_CLS_NAMES = {
@@ -64,18 +61,6 @@
6461
logger = logging.getLogger(__name__)
6562

6663

67-
def is_pipeline_parameters(value):
68-
"""Determine if a value is a pipeline parameter or function representation
69-
70-
Args:
71-
value (float or int): The value to be verified.
72-
73-
Returns:
74-
bool: True if it is, False otherwise.
75-
"""
76-
return isinstance(value, (PipelineParameter, PipelineJsonGet, PipelineJoin))
77-
78-
7964
class WarmStartTypes(Enum):
8065
"""Warm Start Configuration type.
8166
@@ -377,7 +362,7 @@ def _prepare_static_hyperparameters(
377362
"""Prepare static hyperparameters for one estimator before tuning."""
378363
# Remove any hyperparameter that will be tuned
379364
static_hyperparameters = {
380-
str(k): str(v) if not isinstance(v, PipelineVariable) else v.to_string()
365+
str(k): str(v) if not is_pipeline_variable(v) else v.to_string()
381366
for (k, v) in estimator.hyperparameters().items()
382367
}
383368
for hyperparameter_name in hyperparameter_ranges.keys():

src/sagemaker/workflow/__init__.py

+15
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,18 @@
2020
from sagemaker.workflow.properties import Properties
2121

2222
PipelineNonPrimitiveInputTypes = Union[ExecutionVariable, Expression, Parameter, Properties]
23+
24+
25+
def is_pipeline_variable(var: object) -> bool:
26+
"""Check if the variable is a pipeline variable
27+
28+
Args:
29+
var (object): The variable to be verified.
30+
Returns:
31+
bool: True if it is, False otherwise.
32+
"""
33+
34+
# Currently Expression is on top of all kinds of pipeline variables
35+
# as well as PipelineExperimentConfigProperty and PropertyFile
36+
# TODO: We should deprecate the Expression and replace it with PipelineVariable
37+
return isinstance(var, Expression)

src/sagemaker/workflow/_repack_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,14 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None):
3939
4040
Args:
4141
inference_script (str): The path to the custom entry point.
42-
model_archive (str): The name of the model TAR archive.
42+
model_archive (str): The name or path (e.g. s3 uri) of the model TAR archive.
4343
dependencies (str): A space-delimited string of paths to custom dependencies.
4444
source_dir (str): The path to a custom source directory.
4545
"""
4646

4747
# the data directory contains a model archive generated by a previous training job
4848
data_directory = "/opt/ml/input/data/training"
49-
model_path = os.path.join(data_directory, model_archive)
49+
model_path = os.path.join(data_directory, model_archive.split("/")[-1])
5050

5151
# create a temporary directory
5252
with tempfile.TemporaryDirectory() as tmp:

src/sagemaker/workflow/_utils.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,6 @@ def __init__(
134134
self._model_data = model_data
135135
self.sagemaker_session = sagemaker_session
136136
self.role = role
137-
if isinstance(model_data, Properties):
138-
self._model_prefix = model_data
139-
self._model_archive = "model.tar.gz"
140-
else:
141-
self._model_prefix = "/".join(self._model_data.split("/")[:-1])
142-
self._model_archive = self._model_data.split("/")[-1]
143137
self._entry_point = entry_point
144138
self._entry_point_basename = os.path.basename(self._entry_point)
145139
self._source_dir = source_dir
@@ -161,7 +155,7 @@ def __init__(
161155
role=self.role,
162156
hyperparameters={
163157
"inference_script": self._entry_point_basename,
164-
"model_archive": self._model_archive,
158+
"model_archive": self._model_data,
165159
"dependencies": dependencies_hyperparameter,
166160
"source_dir": self._source_dir,
167161
},
@@ -170,7 +164,7 @@ def __init__(
170164
**kwargs,
171165
)
172166
repacker.disable_profiler = True
173-
inputs = TrainingInput(self._model_prefix)
167+
inputs = TrainingInput(self._model_data)
174168

175169
# super!
176170
super(_RepackModelStep, self).__init__(

src/sagemaker/workflow/airflow.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,9 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=
184184
train_config["VpcConfig"] = job_config["vpc_config"]
185185

186186
if estimator.use_spot_instances:
187-
train_config["EnableManagedSpotTraining"] = True
187+
# estimator.use_spot_instances may be a Pipeline ParameterBoolean object
188+
# which is parsed during the Pipeline execution runtime
189+
train_config["EnableManagedSpotTraining"] = estimator.use_spot_instances
188190

189191
if estimator.hyperparameters() is not None:
190192
hyperparameters = {str(k): str(v) for (k, v) in estimator.hyperparameters().items()}

src/sagemaker/workflow/clarify_check_step.py

+6-9
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@
3737
from sagemaker.model_monitor.model_monitoring import _MODEL_MONITOR_S3_PATH
3838
from sagemaker.processing import ProcessingInput, ProcessingOutput, ProcessingJob
3939
from sagemaker.utils import name_from_base
40-
from sagemaker.workflow import PipelineNonPrimitiveInputTypes, ExecutionVariable, Parameter
41-
from sagemaker.workflow.entities import RequestType, Expression
40+
from sagemaker.workflow import PipelineNonPrimitiveInputTypes, is_pipeline_variable
41+
from sagemaker.workflow.entities import RequestType
4242
from sagemaker.workflow.properties import Properties
4343
from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig
4444
from sagemaker.workflow.check_job_config import CheckJobConfig
@@ -193,18 +193,15 @@ def __init__(
193193
+ "DataBiasCheckConfig, ModelBiasCheckConfig or ModelExplainabilityCheckConfig"
194194
)
195195

196-
if isinstance(
197-
clarify_check_config.data_config.s3_analysis_config_output_path,
198-
(ExecutionVariable, Expression, Parameter, Properties),
199-
):
196+
if is_pipeline_variable(clarify_check_config.data_config.s3_analysis_config_output_path):
200197
raise RuntimeError(
201198
"s3_analysis_config_output_path cannot be of type "
202199
+ "ExecutionVariable/Expression/Parameter/Properties"
203200
)
204201

205-
if not clarify_check_config.data_config.s3_analysis_config_output_path and isinstance(
206-
clarify_check_config.data_config.s3_output_path,
207-
(ExecutionVariable, Expression, Parameter, Properties),
202+
if (
203+
not clarify_check_config.data_config.s3_analysis_config_output_path
204+
and is_pipeline_variable(clarify_check_config.data_config.s3_output_path)
208205
):
209206
raise RuntimeError(
210207
"`s3_output_path` cannot be of type ExecutionVariable/Expression/Parameter"

src/sagemaker/workflow/condition_step.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
from sagemaker.workflow.step_collections import StepCollection
2727
from sagemaker.workflow.utilities import list_to_request
2828
from sagemaker.workflow.entities import (
29-
Expression,
3029
RequestType,
30+
PipelineVariable,
3131
)
3232
from sagemaker.workflow.properties import (
3333
Properties,
@@ -95,7 +95,7 @@ def properties(self):
9595

9696

9797
@attr.s
98-
class JsonGet(Expression): # pragma: no cover
98+
class JsonGet(PipelineVariable): # pragma: no cover
9999
"""Get JSON properties from PropertyFiles.
100100
101101
Attributes:

src/sagemaker/workflow/conditions.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import attr
2424

25+
from sagemaker.workflow import is_pipeline_variable
2526
from sagemaker.workflow.entities import (
2627
DefaultEnumMeta,
2728
Entity,
@@ -261,6 +262,6 @@ def primitive_or_expr(
261262
Returns:
262263
Either the expression of the value or the primitive value.
263264
"""
264-
if isinstance(value, (ExecutionVariable, Expression, Parameter, Properties)):
265+
if is_pipeline_variable(value):
265266
return value.expr
266267
return value

src/sagemaker/workflow/pipeline.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -339,16 +339,20 @@ def interpolate(
339339
Args:
340340
request_obj (RequestType): The request dict.
341341
callback_output_to_step_map (Dict[str, str]): A dict of output name -> step name.
342+
lambda_output_to_step_map (Dict[str, str]): A dict of output name -> step name.
342343
343344
Returns:
344345
RequestType: The request dict with Parameter values replaced by their expression.
345346
"""
346-
request_obj_copy = deepcopy(request_obj)
347-
return _interpolate(
348-
request_obj_copy,
349-
callback_output_to_step_map=callback_output_to_step_map,
350-
lambda_output_to_step_map=lambda_output_to_step_map,
351-
)
347+
try:
348+
request_obj_copy = deepcopy(request_obj)
349+
return _interpolate(
350+
request_obj_copy,
351+
callback_output_to_step_map=callback_output_to_step_map,
352+
lambda_output_to_step_map=lambda_output_to_step_map,
353+
)
354+
except TypeError as type_err:
355+
raise TypeError("Not able to interpolate Pipeline definition: %s" % type_err)
352356

353357

354358
def _interpolate(

0 commit comments

Comments (0)