Skip to content

Commit 134c6ac

Browse files
author
Dewen Qi
committed
fix: Fix Pipeline variables related customer issues
1 parent 6670e30 commit 134c6ac

19 files changed

+426
-87
lines changed

src/sagemaker/estimator.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1869,7 +1869,9 @@ def _add_spot_checkpoint_args(cls, local_mode, estimator, train_args):
18691869
if estimator.use_spot_instances:
18701870
if local_mode:
18711871
raise ValueError("Spot training is not supported in local mode.")
1872-
train_args["use_spot_instances"] = True
1872+
# estimator.use_spot_instances may be a Pipeline ParameterBoolean object
1873+
# which is parsed during the Pipeline execution runtime
1874+
train_args["use_spot_instances"] = estimator.use_spot_instances
18731875

18741876
if estimator.checkpoint_s3_uri:
18751877
if local_mode:

src/sagemaker/model.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from sagemaker.utils import unique_name_from_base
3838
from sagemaker.async_inference import AsyncInferenceConfig
3939
from sagemaker.predictor_async import AsyncPredictor
40+
from sagemaker.workflow import is_pipeline_variable
4041

4142
LOGGER = logging.getLogger("sagemaker")
4243

@@ -443,7 +444,7 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
443444
)
444445

445446
if repack and self.model_data is not None and self.entry_point is not None:
446-
if isinstance(self.model_data, sagemaker.workflow.properties.Properties):
447+
if is_pipeline_variable(self.model_data):
447448
# model is not yet there, defer repacking to later during pipeline execution
448449
return
449450

src/sagemaker/parameter.py

+5-9
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@
1414
from __future__ import absolute_import
1515

1616
import json
17-
from sagemaker.workflow.parameters import Parameter as PipelineParameter
18-
from sagemaker.workflow.functions import JsonGet as PipelineJsonGet
19-
from sagemaker.workflow.functions import Join as PipelineJoin
17+
18+
from sagemaker.workflow import is_pipeline_variable
2019

2120

2221
class ParameterRange(object):
@@ -73,10 +72,10 @@ def as_tuning_range(self, name):
7372
return {
7473
"Name": name,
7574
"MinValue": str(self.min_value)
76-
if not isinstance(self.min_value, (PipelineParameter, PipelineJsonGet, PipelineJoin))
75+
if not is_pipeline_variable(self.min_value)
7776
else self.min_value,
7877
"MaxValue": str(self.max_value)
79-
if not isinstance(self.max_value, (PipelineParameter, PipelineJsonGet, PipelineJoin))
78+
if not is_pipeline_variable(self.max_value)
8079
else self.max_value,
8180
"ScalingType": self.scaling_type,
8281
}
@@ -111,10 +110,7 @@ def __init__(self, values): # pylint: disable=super-init-not-called
111110
This input will be converted into a list of strings.
112111
"""
113112
values = values if isinstance(values, list) else [values]
114-
self.values = [
115-
str(v) if not isinstance(v, (PipelineParameter, PipelineJsonGet, PipelineJoin)) else v
116-
for v in values
117-
]
113+
self.values = [str(v) if not is_pipeline_variable(v) else v for v in values]
118114

119115
def as_tuning_range(self, name):
120116
"""Represent the parameter range as a dictionary.

src/sagemaker/session.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -763,6 +763,8 @@ def _get_train_request( # noqa: C901
763763
train_request["EnableInterContainerTrafficEncryption"] = encrypt_inter_container_traffic
764764

765765
if use_spot_instances:
766+
# estimator.use_spot_instances may be a Pipeline ParameterBoolean object
767+
# which is parsed during the Pipeline execution runtime
766768
train_request["EnableManagedSpotTraining"] = use_spot_instances
767769

768770
if checkpoint_s3_uri:
@@ -2338,13 +2340,17 @@ def _map_training_config(
23382340
training_job_definition["VpcConfig"] = vpc_config
23392341

23402342
if enable_network_isolation:
2341-
training_job_definition["EnableNetworkIsolation"] = True
2343+
training_job_definition["EnableNetworkIsolation"] = enable_network_isolation
23422344

23432345
if encrypt_inter_container_traffic:
2344-
training_job_definition["EnableInterContainerTrafficEncryption"] = True
2346+
training_job_definition[
2347+
"EnableInterContainerTrafficEncryption"
2348+
] = encrypt_inter_container_traffic
23452349

23462350
if use_spot_instances:
2347-
training_job_definition["EnableManagedSpotTraining"] = True
2351+
# use_spot_instances may be a Pipeline ParameterBoolean object
2352+
# which is parsed during the Pipeline execution runtime
2353+
training_job_definition["EnableManagedSpotTraining"] = use_spot_instances
23482354

23492355
if checkpoint_s3_uri:
23502356
checkpoint_config = {"S3Uri": checkpoint_s3_uri}

src/sagemaker/tensorflow/model.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from sagemaker.deprecations import removed_kwargs
2222
from sagemaker.predictor import Predictor
2323
from sagemaker.serializers import JSONSerializer
24+
from sagemaker.workflow import is_pipeline_variable
2425

2526

2627
class TensorFlowPredictor(Predictor):
@@ -326,7 +327,9 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
326327
image_uri = self._get_image_uri(instance_type, accelerator_type)
327328
env = self._get_container_env()
328329

329-
if self.entry_point:
330+
# If self.model_data is pipeline variable, model is not yet there.
331+
# So defer repacking to later during pipeline execution
332+
if self.entry_point and not is_pipeline_variable(self.model_data):
330333
key_prefix = sagemaker.fw_utils.model_code_key_prefix(
331334
self.key_prefix, self.name, image_uri
332335
)

src/sagemaker/tuner.py

+2-18
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,7 @@
3838
IntegerParameter,
3939
ParameterRange,
4040
)
41-
from sagemaker.workflow.parameters import Parameter as PipelineParameter
42-
from sagemaker.workflow.functions import JsonGet as PipelineJsonGet
43-
from sagemaker.workflow.functions import Join as PipelineJoin
41+
from sagemaker.workflow import is_pipeline_variable
4442

4543
from sagemaker.session import Session
4644
from sagemaker.utils import base_from_name, base_name_from_image, name_from_base
@@ -63,18 +61,6 @@
6361
logger = logging.getLogger(__name__)
6462

6563

66-
def is_pipeline_parameters(value):
67-
"""Determine if a value is a pipeline parameter or function representation
68-
69-
Args:
70-
value (float or int): The value to be verified.
71-
72-
Returns:
73-
bool: True if it is, False otherwise.
74-
"""
75-
return isinstance(value, (PipelineParameter, PipelineJsonGet, PipelineJoin))
76-
77-
7864
class WarmStartTypes(Enum):
7965
"""Warm Start Configuration type.
8066
@@ -376,9 +362,7 @@ def _prepare_static_hyperparameters(
376362
"""Prepare static hyperparameters for one estimator before tuning."""
377363
# Remove any hyperparameter that will be tuned
378364
static_hyperparameters = {
379-
str(k): str(v)
380-
if not isinstance(v, (PipelineParameter, PipelineJsonGet, PipelineJoin))
381-
else v
365+
str(k): str(v) if not is_pipeline_variable(v) else v
382366
for (k, v) in estimator.hyperparameters().items()
383367
}
384368
for hyperparameter_name in hyperparameter_ranges.keys():

src/sagemaker/workflow/__init__.py

+12
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,15 @@
2020
from sagemaker.workflow.properties import Properties
2121

2222
PipelineNonPrimitiveInputTypes = Union[ExecutionVariable, Expression, Parameter, Properties]
23+
24+
25+
def is_pipeline_variable(var: object) -> bool:
    """Check whether ``var`` is a SageMaker Pipeline variable.

    Pipeline variables (``Parameter``, ``Properties``, ``Expression``,
    ``ExecutionVariable``) are placeholders whose concrete values are only
    resolved at pipeline execution runtime, so callers must not coerce them
    (e.g. via ``str()``) at pipeline definition time.

    Args:
        var (object): The variable to be verified.

    Returns:
        bool: True if ``var`` is a pipeline variable, False otherwise.
    """
    return isinstance(var, (Properties, Parameter, Expression, ExecutionVariable))

src/sagemaker/workflow/_repack_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,14 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None):
3939
4040
Args:
4141
inference_script (str): The path to the custom entry point.
42-
model_archive (str): The name of the model TAR archive.
42+
model_archive (str): The name or path (e.g. s3 uri) of the model TAR archive.
4343
dependencies (str): A space-delimited string of paths to custom dependencies.
4444
source_dir (str): The path to a custom source directory.
4545
"""
4646

4747
# the data directory contains a model archive generated by a previous training job
4848
data_directory = "/opt/ml/input/data/training"
49-
model_path = os.path.join(data_directory, model_archive)
49+
model_path = os.path.join(data_directory, model_archive.split("/")[-1])
5050

5151
# create a temporary directory
5252
with tempfile.TemporaryDirectory() as tmp:

src/sagemaker/workflow/_utils.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,6 @@ def __init__(
137137
self._model_data = model_data
138138
self.sagemaker_session = sagemaker_session
139139
self.role = role
140-
if isinstance(model_data, Properties):
141-
self._model_prefix = model_data
142-
self._model_archive = "model.tar.gz"
143-
else:
144-
self._model_prefix = "/".join(self._model_data.split("/")[:-1])
145-
self._model_archive = self._model_data.split("/")[-1]
146140
self._entry_point = entry_point
147141
self._entry_point_basename = os.path.basename(self._entry_point)
148142
self._source_dir = source_dir
@@ -164,7 +158,7 @@ def __init__(
164158
role=self.role,
165159
hyperparameters={
166160
"inference_script": self._entry_point_basename,
167-
"model_archive": self._model_archive,
161+
"model_archive": self._model_data,
168162
"dependencies": dependencies_hyperparameter,
169163
"source_dir": self._source_dir,
170164
},
@@ -173,7 +167,7 @@ def __init__(
173167
**kwargs,
174168
)
175169
repacker.disable_profiler = True
176-
inputs = TrainingInput(self._model_prefix)
170+
inputs = TrainingInput(self._model_data)
177171

178172
# super!
179173
super(_RepackModelStep, self).__init__(

src/sagemaker/workflow/airflow.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,9 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=
184184
train_config["VpcConfig"] = job_config["vpc_config"]
185185

186186
if estimator.use_spot_instances:
187-
train_config["EnableManagedSpotTraining"] = True
187+
# estimator.use_spot_instances may be a Pipeline ParameterBoolean object
188+
# which is parsed during the Pipeline execution runtime
189+
train_config["EnableManagedSpotTraining"] = estimator.use_spot_instances
188190

189191
if estimator.hyperparameters() is not None:
190192
hyperparameters = {str(k): str(v) for (k, v) in estimator.hyperparameters().items()}

src/sagemaker/workflow/clarify_check_step.py

+6-9
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@
3737
from sagemaker.model_monitor.model_monitoring import _MODEL_MONITOR_S3_PATH
3838
from sagemaker.processing import ProcessingInput, ProcessingOutput, ProcessingJob
3939
from sagemaker.utils import name_from_base
40-
from sagemaker.workflow import PipelineNonPrimitiveInputTypes, ExecutionVariable, Parameter
41-
from sagemaker.workflow.entities import RequestType, Expression
40+
from sagemaker.workflow import PipelineNonPrimitiveInputTypes, is_pipeline_variable
41+
from sagemaker.workflow.entities import RequestType
4242
from sagemaker.workflow.properties import Properties
4343
from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig
4444
from sagemaker.workflow.check_job_config import CheckJobConfig
@@ -193,18 +193,15 @@ def __init__(
193193
+ "DataBiasCheckConfig, ModelBiasCheckConfig or ModelExplainabilityCheckConfig"
194194
)
195195

196-
if isinstance(
197-
clarify_check_config.data_config.s3_analysis_config_output_path,
198-
(ExecutionVariable, Expression, Parameter, Properties),
199-
):
196+
if is_pipeline_variable(clarify_check_config.data_config.s3_analysis_config_output_path):
200197
raise RuntimeError(
201198
"s3_analysis_config_output_path cannot be of type "
202199
+ "ExecutionVariable/Expression/Parameter/Properties"
203200
)
204201

205-
if not clarify_check_config.data_config.s3_analysis_config_output_path and isinstance(
206-
clarify_check_config.data_config.s3_output_path,
207-
(ExecutionVariable, Expression, Parameter, Properties),
202+
if (
203+
not clarify_check_config.data_config.s3_analysis_config_output_path
204+
and is_pipeline_variable(clarify_check_config.data_config.s3_output_path)
208205
):
209206
raise RuntimeError(
210207
"`s3_output_path` cannot be of type ExecutionVariable/Expression/Parameter"

src/sagemaker/workflow/conditions.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import attr
2424

25+
from sagemaker.workflow import is_pipeline_variable
2526
from sagemaker.workflow.entities import (
2627
DefaultEnumMeta,
2728
Entity,
@@ -261,6 +262,6 @@ def primitive_or_expr(
261262
Returns:
262263
Either the expression of the value or the primitive value.
263264
"""
264-
if isinstance(value, (ExecutionVariable, Expression, Parameter, Properties)):
265+
if is_pipeline_variable(value):
265266
return value.expr
266267
return value

src/sagemaker/workflow/entities.py

+13
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,16 @@ class Expression(abc.ABC):
5757
@abc.abstractmethod
5858
def expr(self) -> RequestType:
5959
"""Get the expression structure for workflow service calls."""
60+
61+
def startswith(self, prefix, start=None, end=None) -> bool:  # pylint: disable=unused-argument
    """Mimic the built-in ``str.startswith`` for pipeline variables.

    A pipeline variable has no concrete string value until pipeline
    execution runtime, so no prefix match can ever be confirmed while the
    pipeline is being defined.

    Args:
        prefix (str, tuple): The (tuple of) string to be checked.
        start (int): Start index of the matching boundary (default: None).
        end (int): End index of the matching boundary (default: None).

    Return:
        bool: always False, as the value is unknown until execution runtime.
    """
    return False

src/sagemaker/workflow/quality_check_step.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
from sagemaker import s3
2323
from sagemaker.model_monitor import ModelMonitor
2424
from sagemaker.processing import ProcessingOutput, ProcessingJob, Processor, ProcessingInput
25-
from sagemaker.workflow import PipelineNonPrimitiveInputTypes, ExecutionVariable, Parameter
25+
from sagemaker.workflow import PipelineNonPrimitiveInputTypes, is_pipeline_variable
2626

27-
from sagemaker.workflow.entities import RequestType, Expression
27+
from sagemaker.workflow.entities import RequestType
2828
from sagemaker.workflow.properties import (
2929
Properties,
3030
)
@@ -279,7 +279,7 @@ def _generate_baseline_job_inputs(self):
279279
_CONTAINER_BASE_PATH, _CONTAINER_INPUT_PATH, _BASELINE_DATASET_INPUT_NAME
280280
)
281281
)
282-
if isinstance(baseline_dataset, (ExecutionVariable, Expression, Parameter, Properties)):
282+
if is_pipeline_variable(baseline_dataset):
283283
baseline_dataset_input = ProcessingInput(
284284
source=self.quality_check_config.baseline_dataset,
285285
destination=baseline_dataset_des,

0 commit comments

Comments (0)