
Commit e464689

Authored by qidewenwhen (Dewen Qi) and Dewen Qi
fix: Fix Pipeline variables related customer issues (aws#2959)
Co-authored-by: Dewen Qi <[email protected]>
1 parent: bc3825e · commit: e464689

20 files changed: +430, -68 lines
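
The common thread in these fixes: pipeline variables (e.g. ParameterBoolean, step Properties, Join expressions) are placeholders that resolve only at pipeline execution time, so SDK code paths that flattened them into literals (such as train_args["use_spot_instances"] = True) or string-manipulated them too early broke at runtime. A minimal sketch of the pattern these changes support (role ARN, script, and parameter names are illustrative, not from this commit):

    from sagemaker.tensorflow import TensorFlow
    from sagemaker.workflow.parameters import ParameterBoolean

    # A pipeline parameter is a placeholder object, not a bool; its value is
    # supplied when the pipeline executes.
    use_spot = ParameterBoolean(name="UseSpotInstances", default_value=False)

    estimator = TensorFlow(
        entry_point="train.py",  # illustrative
        role="arn:aws:iam::111122223333:role/SageMakerRole",  # illustrative
        instance_count=1,
        instance_type="ml.m5.xlarge",
        framework_version="2.4",
        py_version="py37",
        use_spot_instances=use_spot,  # previously flattened to a literal True
    )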

src/sagemaker/estimator.py (+3, -1)

@@ -1879,7 +1879,9 @@ def _add_spot_checkpoint_args(cls, local_mode, estimator, train_args):
         if estimator.use_spot_instances:
             if local_mode:
                 raise ValueError("Spot training is not supported in local mode.")
-            train_args["use_spot_instances"] = True
+            # estimator.use_spot_instances may be a Pipeline ParameterBoolean object
+            # which is parsed during the Pipeline execution runtime
+            train_args["use_spot_instances"] = estimator.use_spot_instances

         if estimator.checkpoint_s3_uri:
             if local_mode:
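
With this change, train_args carries the parameter object through to the training request, where it serializes as a reference that the Pipelines service resolves at execution. A sketch of how such a placeholder renders (the .expr behavior of workflow parameters in this SDK):

    from sagemaker.workflow.parameters import ParameterBoolean

    use_spot = ParameterBoolean(name="UseSpotInstances", default_value=False)
    # Renders as a reference in the pipeline definition, not a literal:
    print(use_spot.expr)  # {'Get': 'Parameters.UseSpotInstances'}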

src/sagemaker/model.py (+2, -1)

@@ -37,6 +37,7 @@
 from sagemaker.utils import unique_name_from_base
 from sagemaker.async_inference import AsyncInferenceConfig
 from sagemaker.predictor_async import AsyncPredictor
+from sagemaker.workflow.entities import PipelineVariable

 LOGGER = logging.getLogger("sagemaker")

@@ -443,7 +444,7 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
         )

         if repack and self.model_data is not None and self.entry_point is not None:
-            if isinstance(self.model_data, sagemaker.workflow.properties.Properties):
+            if isinstance(self.model_data, PipelineVariable):
                 # model is not yet there, defer repacking to later during pipeline execution
                 return
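
PipelineVariable is the common base class for the placeholder types that can appear in model_data, so the broadened check defers repacking not only for step Properties but also for expressions such as Join. A sketch (assuming Join and Parameter derive from PipelineVariable, as this commit's substitutions imply):

    from sagemaker.workflow.entities import PipelineVariable
    from sagemaker.workflow.functions import Join
    from sagemaker.workflow.parameters import ParameterString

    bucket = ParameterString(name="ModelBucket", default_value="my-bucket")  # illustrative
    model_data = Join(on="/", values=["s3:/", bucket, "model.tar.gz"])

    # Both step Properties and expressions like this Join are PipelineVariables,
    # so _upload_code now defers repacking for either.
    assert isinstance(model_data, PipelineVariable)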

src/sagemaker/processing.py (+7, -1)

@@ -36,7 +36,7 @@
 from sagemaker.session import Session
 from sagemaker.workflow.properties import Properties
 from sagemaker.workflow.parameters import Parameter
-from sagemaker.workflow.entities import Expression
+from sagemaker.workflow.entities import Expression, PipelineVariable
 from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition
 from sagemaker.apiutils._base_types import ApiObject
 from sagemaker.s3 import S3Uploader
@@ -233,6 +233,12 @@ def _normalize_args(
             kms_key (str): The ARN of the KMS key that is used to encrypt the
                 user code file (default: None).
         """
+        if code and isinstance(code, PipelineVariable):
+            raise ValueError(
+                "code argument has to be a valid S3 URI or local file path "
+                + "rather than a pipeline variable"
+            )
+
         self._current_job_name = self._generate_current_job_name(job_name=job_name)

         inputs_with_code = self._include_code_in_inputs(inputs, code, kms_key)
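
The guard fails fast instead of letting a placeholder leak into code staging, which the SDK performs before the pipeline ever runs. A sketch of the call that now raises (the processor itself is elided; its construction is not part of this commit):

    from sagemaker.workflow.parameters import ParameterString

    code_param = ParameterString(name="CodeLocation", default_value="s3://my-bucket/code.py")

    # Raises ValueError: code must be a concrete S3 URI or local path,
    # not a pipeline variable.
    processor.run(code=code_param, inputs=[])  # `processor` constructed elsewhere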

src/sagemaker/session.py (+9, -3)

@@ -763,6 +763,8 @@ def _get_train_request(  # noqa: C901
         train_request["EnableInterContainerTrafficEncryption"] = encrypt_inter_container_traffic

         if use_spot_instances:
+            # estimator.use_spot_instances may be a Pipeline ParameterBoolean object
+            # which is parsed during the Pipeline execution runtime
             train_request["EnableManagedSpotTraining"] = use_spot_instances

         if checkpoint_s3_uri:
@@ -2340,13 +2342,17 @@ def _map_training_config(
         training_job_definition["VpcConfig"] = vpc_config

         if enable_network_isolation:
-            training_job_definition["EnableNetworkIsolation"] = True
+            training_job_definition["EnableNetworkIsolation"] = enable_network_isolation

         if encrypt_inter_container_traffic:
-            training_job_definition["EnableInterContainerTrafficEncryption"] = True
+            training_job_definition[
+                "EnableInterContainerTrafficEncryption"
+            ] = encrypt_inter_container_traffic

         if use_spot_instances:
-            training_job_definition["EnableManagedSpotTraining"] = True
+            # use_spot_instances may be a Pipeline ParameterBoolean object
+            # which is parsed during the Pipeline execution runtime
+            training_job_definition["EnableManagedSpotTraining"] = use_spot_instances

         if checkpoint_s3_uri:
             checkpoint_config = {"S3Uri": checkpoint_s3_uri}

src/sagemaker/tensorflow/model.py (+4, -1)

@@ -21,6 +21,7 @@
 from sagemaker.deprecations import removed_kwargs
 from sagemaker.predictor import Predictor
 from sagemaker.serializers import JSONSerializer
+from sagemaker.workflow.entities import PipelineVariable


 class TensorFlowPredictor(Predictor):
@@ -330,7 +331,9 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
         image_uri = self._get_image_uri(instance_type, accelerator_type)
         env = self._get_container_env()

-        if self.entry_point:
+        # If self.model_data is pipeline variable, model is not yet there.
+        # So defer repacking to later during pipeline execution
+        if self.entry_point and not isinstance(self.model_data, PipelineVariable):
             key_prefix = sagemaker.fw_utils.model_code_key_prefix(
                 self.key_prefix, self.name, image_uri
             )
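
This is the scenario the new integration test at the bottom of this commit exercises: model_data that points at an artifact which does not exist until the training step runs. A sketch mirroring that test (step_train, role, and sagemaker_session as defined there):

    from sagemaker.tensorflow import TensorFlowModel

    # model_data is a Properties placeholder, so prepare_container_def skips
    # the eager repack; repacking happens during pipeline execution instead.
    model = TensorFlowModel(
        entry_point="inference.py",  # illustrative
        framework_version="2.4",
        model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
        role=role,
        sagemaker_session=sagemaker_session,
    )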

src/sagemaker/tuner.py (-15)

@@ -39,9 +39,6 @@
     ParameterRange,
 )
 from sagemaker.workflow.entities import PipelineVariable
-from sagemaker.workflow.parameters import Parameter as PipelineParameter
-from sagemaker.workflow.functions import JsonGet as PipelineJsonGet
-from sagemaker.workflow.functions import Join as PipelineJoin

 from sagemaker.session import Session
 from sagemaker.utils import base_from_name, base_name_from_image, name_from_base
@@ -64,18 +61,6 @@
 logger = logging.getLogger(__name__)


-def is_pipeline_parameters(value):
-    """Determine if a value is a pipeline parameter or function representation
-
-    Args:
-        value (float or int): The value to be verified.
-
-    Returns:
-        bool: True if it is, False otherwise.
-    """
-    return isinstance(value, (PipelineParameter, PipelineJsonGet, PipelineJoin))
-
-
 class WarmStartTypes(Enum):
     """Warm Start Configuration type.

src/sagemaker/workflow/_repack_model.py (+2, -2)

@@ -39,14 +39,14 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None):

     Args:
         inference_script (str): The path to the custom entry point.
-        model_archive (str): The name of the model TAR archive.
+        model_archive (str): The name or path (e.g. s3 uri) of the model TAR archive.
         dependencies (str): A space-delimited string of paths to custom dependencies.
         source_dir (str): The path to a custom source directory.
     """

     # the data directory contains a model archive generated by a previous training job
     data_directory = "/opt/ml/input/data/training"
-    model_path = os.path.join(data_directory, model_archive)
+    model_path = os.path.join(data_directory, model_archive.split("/")[-1])

     # create a temporary directory
     with tempfile.TemporaryDirectory() as tmp:
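
Since _utils.py (below) now passes the full model_data value as the model_archive hyperparameter, the script may receive either a bare archive name or a full S3 URI; taking the last path segment handles both, because the training channel stages the archive flat under the data directory. A quick sketch of the new behavior:

    import os

    data_directory = "/opt/ml/input/data/training"
    for model_archive in ("model.tar.gz", "s3://my-bucket/prefix/model.tar.gz"):
        # Both forms resolve to /opt/ml/input/data/training/model.tar.gz
        print(os.path.join(data_directory, model_archive.split("/")[-1]))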

src/sagemaker/workflow/_utils.py (+2, -8)

@@ -137,12 +137,6 @@ def __init__(
         self._model_data = model_data
         self.sagemaker_session = sagemaker_session
         self.role = role
-        if isinstance(model_data, Properties):
-            self._model_prefix = model_data
-            self._model_archive = "model.tar.gz"
-        else:
-            self._model_prefix = "/".join(self._model_data.split("/")[:-1])
-            self._model_archive = self._model_data.split("/")[-1]
         self._entry_point = entry_point
         self._entry_point_basename = os.path.basename(self._entry_point)
         self._source_dir = source_dir
@@ -164,7 +158,7 @@ def __init__(
             role=self.role,
             hyperparameters={
                 "inference_script": self._entry_point_basename,
-                "model_archive": self._model_archive,
+                "model_archive": self._model_data,
                 "dependencies": dependencies_hyperparameter,
                 "source_dir": self._source_dir,
             },
@@ -173,7 +167,7 @@ def __init__(
             **kwargs,
         )
         repacker.disable_profiler = True
-        inputs = TrainingInput(self._model_prefix)
+        inputs = TrainingInput(self._model_data)

         # super!
         super(_RepackModelStep, self).__init__(
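
With the prefix/archive bookkeeping removed, the unmodified model_data flows into the repack job both as the model_archive hyperparameter and as the training input, so a placeholder URI resolves only when the step actually runs. A sketch of the pipeline case (step_train as in the integration test below):

    from sagemaker.inputs import TrainingInput

    # A step property is passed as-is; it becomes a concrete S3 URI at runtime.
    inputs = TrainingInput(step_train.properties.ModelArtifacts.S3ModelArtifacts)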

src/sagemaker/workflow/airflow.py (+3, -1)

@@ -184,7 +184,9 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=
        train_config["VpcConfig"] = job_config["vpc_config"]

    if estimator.use_spot_instances:
-        train_config["EnableManagedSpotTraining"] = True
+        # estimator.use_spot_instances may be a Pipeline ParameterBoolean object
+        # which is parsed during the Pipeline execution runtime
+        train_config["EnableManagedSpotTraining"] = estimator.use_spot_instances

    if estimator.hyperparameters() is not None:
        hyperparameters = {str(k): str(v) for (k, v) in estimator.hyperparameters().items()}

src/sagemaker/workflow/clarify_check_step.py (+4, -6)

@@ -37,8 +37,8 @@
 from sagemaker.model_monitor.model_monitoring import _MODEL_MONITOR_S3_PATH
 from sagemaker.processing import ProcessingInput, ProcessingOutput, ProcessingJob
 from sagemaker.utils import name_from_base
-from sagemaker.workflow import PipelineNonPrimitiveInputTypes, ExecutionVariable, Parameter
-from sagemaker.workflow.entities import RequestType, Expression
+from sagemaker.workflow import PipelineNonPrimitiveInputTypes
+from sagemaker.workflow.entities import RequestType, PipelineVariable
 from sagemaker.workflow.properties import Properties
 from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig
 from sagemaker.workflow.check_job_config import CheckJobConfig
@@ -194,17 +194,15 @@ def __init__(
         )

         if isinstance(
-            clarify_check_config.data_config.s3_analysis_config_output_path,
-            (ExecutionVariable, Expression, Parameter, Properties),
+            clarify_check_config.data_config.s3_analysis_config_output_path, PipelineVariable
         ):
             raise RuntimeError(
                 "s3_analysis_config_output_path cannot be of type "
                 + "ExecutionVariable/Expression/Parameter/Properties"
             )

         if not clarify_check_config.data_config.s3_analysis_config_output_path and isinstance(
-            clarify_check_config.data_config.s3_output_path,
-            (ExecutionVariable, Expression, Parameter, Properties),
+            clarify_check_config.data_config.s3_output_path, PipelineVariable
         ):
             raise RuntimeError(
                 "`s3_output_path` cannot be of type ExecutionVariable/Expression/Parameter"

src/sagemaker/workflow/conditions.py (+2, -1)

@@ -28,6 +28,7 @@
     Expression,
     PrimitiveType,
     RequestType,
+    PipelineVariable,
 )
 from sagemaker.workflow.execution_variables import ExecutionVariable
 from sagemaker.workflow.parameters import Parameter
@@ -261,6 +262,6 @@ def primitive_or_expr(
     Returns:
         Either the expression of the value or the primitive value.
     """
-    if isinstance(value, (ExecutionVariable, Expression, Parameter, Properties)):
+    if isinstance(value, PipelineVariable):
         return value.expr
     return value
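
primitive_or_expr now leans on PipelineVariable as the single base of everything the old four-way tuple enumerated (parameters, execution variables, expressions such as Join/JsonGet, and step properties). A sketch of the two branches (assuming those types all derive from PipelineVariable in this SDK version):

    from sagemaker.workflow.execution_variables import ExecutionVariables
    from sagemaker.workflow.parameters import ParameterInteger

    threshold = ParameterInteger(name="Threshold", default_value=10)

    # Placeholders serialize via .expr ...
    print(threshold.expr)  # {'Get': 'Parameters.Threshold'}
    print(ExecutionVariables.PIPELINE_EXECUTION_ID.expr)  # {'Get': 'Execution.PipelineExecutionId'}
    # ... while primitives (e.g. the int 10) pass through unchanged.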

src/sagemaker/workflow/quality_check_step.py (+3, -3)

@@ -22,9 +22,9 @@
 from sagemaker import s3
 from sagemaker.model_monitor import ModelMonitor
 from sagemaker.processing import ProcessingOutput, ProcessingJob, Processor, ProcessingInput
-from sagemaker.workflow import PipelineNonPrimitiveInputTypes, ExecutionVariable, Parameter
+from sagemaker.workflow import PipelineNonPrimitiveInputTypes

-from sagemaker.workflow.entities import RequestType, Expression
+from sagemaker.workflow.entities import RequestType, PipelineVariable
 from sagemaker.workflow.properties import (
     Properties,
 )
@@ -279,7 +279,7 @@ def _generate_baseline_job_inputs(self):
                 _CONTAINER_BASE_PATH, _CONTAINER_INPUT_PATH, _BASELINE_DATASET_INPUT_NAME
             )
         )
-        if isinstance(baseline_dataset, (ExecutionVariable, Expression, Parameter, Properties)):
+        if isinstance(baseline_dataset, PipelineVariable):
             baseline_dataset_input = ProcessingInput(
                 source=self.quality_check_config.baseline_dataset,
                 destination=baseline_dataset_des,

tests/integ/sagemaker/workflow/test_model_registration.py (+99)

@@ -20,6 +20,7 @@
 from botocore.exceptions import WaiterError

 import tests
+from sagemaker.tensorflow import TensorFlow, TensorFlowModel
 from tests.integ.retry import retries
 from sagemaker.drift_check_baselines import DriftCheckBaselines
 from sagemaker import (
@@ -745,3 +746,101 @@ def test_model_registration_with_model_repack(
             pipeline.delete()
         except Exception:
             pass
+
+
+def test_model_registration_with_tensorflow_model_with_pipeline_model(
+    sagemaker_session, role, tf_full_version, tf_full_py_version, pipeline_name, region_name
+):
+    base_dir = os.path.join(DATA_DIR, "tensorflow_mnist")
+    entry_point = os.path.join(base_dir, "mnist_v2.py")
+    input_path = sagemaker_session.upload_data(
+        path=os.path.join(base_dir, "data"),
+        key_prefix="integ-test-data/tf-scriptmode/mnist/training",
+    )
+    inputs = TrainingInput(s3_data=input_path)
+
+    instance_count = ParameterInteger(name="InstanceCount", default_value=1)
+    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
+
+    tensorflow_estimator = TensorFlow(
+        entry_point=entry_point,
+        role=role,
+        instance_count=instance_count,
+        instance_type=instance_type,
+        framework_version=tf_full_version,
+        py_version=tf_full_py_version,
+        sagemaker_session=sagemaker_session,
+    )
+    step_train = TrainingStep(
+        name="MyTrain",
+        estimator=tensorflow_estimator,
+        inputs=inputs,
+    )
+
+    model = TensorFlowModel(
+        entry_point=entry_point,
+        framework_version="2.4",
+        model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
+        role=role,
+        sagemaker_session=sagemaker_session,
+    )
+
+    pipeline_model = PipelineModel(
+        name="MyModelPipeline", models=[model], role=role, sagemaker_session=sagemaker_session
+    )
+
+    step_register_model = RegisterModel(
+        name="MyRegisterModel",
+        model=pipeline_model,
+        model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
+        content_types=["application/json"],
+        response_types=["application/json"],
+        inference_instances=["ml.t2.medium", "ml.m5.large"],
+        transform_instances=["ml.m5.large"],
+        model_package_group_name=f"{pipeline_name}TestModelPackageGroup",
+    )
+
+    pipeline = Pipeline(
+        name=pipeline_name,
+        parameters=[
+            instance_count,
+            instance_type,
+        ],
+        steps=[step_train, step_register_model],
+        sagemaker_session=sagemaker_session,
+    )
+
+    try:
+        response = pipeline.create(role)
+        create_arn = response["PipelineArn"]
+
+        assert re.match(
+            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
+            create_arn,
+        )
+
+        for _ in retries(
+            max_retry_count=5,
+            exception_message_prefix="Waiting for a successful execution of pipeline",
+            seconds_to_sleep=10,
+        ):
+            execution = pipeline.start(parameters={})
+            assert re.match(
+                rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
+                execution.arn,
+            )
+            try:
+                execution.wait(delay=30, max_attempts=60)
+            except WaiterError:
+                pass
+            execution_steps = execution.list_steps()
+
+            assert len(execution_steps) == 3
+            for step in execution_steps:
+                assert step["StepStatus"] == "Succeeded"
+            break
+    finally:
+        try:
+            pipeline.delete()
+        except Exception:
+            pass
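
Note the assertion that three execution steps succeed even though the pipeline defines only two: registering a model that carries a custom entry_point inserts a repack step at execution time (the _RepackModelStep updated elsewhere in this commit), which accounts for the third step.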
