Skip to content

Commit edd2463

Browse files
authored
fix: Prevent passing PipelineVariable object into image_uris.retrieve (#3054)
1 parent 617bfab commit edd2463

13 files changed

+107
-62
lines changed

src/sagemaker/image_uris.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from sagemaker.jumpstart.utils import is_jumpstart_model_input
2424
from sagemaker.spark import defaults
2525
from sagemaker.jumpstart import artifacts
26+
from sagemaker.workflow import is_pipeline_variable
2627

2728
logger = logging.getLogger(__name__)
2829

@@ -104,11 +105,17 @@ def retrieve(
104105
105106
Raises:
106107
NotImplementedError: If the scope is not supported.
107-
ValueError: If the combination of arguments specified is not supported.
108+
ValueError: If the combination of arguments specified is not supported or
109+
any PipelineVariable object is passed in.
108110
VulnerableJumpStartModelError: If any of the dependencies required by the script have
109111
known security vulnerabilities.
110112
DeprecatedJumpStartModelError: If the version of the model is deprecated.
111113
"""
114+
args = dict(locals())
115+
for name, val in args.items():
116+
if is_pipeline_variable(val):
117+
raise ValueError("%s should not be a pipeline variable (%s)" % (name, type(val)))
118+
112119
if is_jumpstart_model_input(model_id, model_version):
113120
return artifacts._retrieve_image_uri(
114121
model_id,

src/sagemaker/workflow/entities.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,11 @@ def __add__(self, other: Union[Expression, PrimitiveType]):
7878

7979
def __str__(self):
8080
"""Override built-in String function for PipelineVariable"""
81-
raise TypeError("Pipeline variables do not support __str__ operation.")
81+
raise TypeError(
82+
"Pipeline variables do not support __str__ operation. "
83+
"Please use `.to_string()` to convert it to string type in execution time "
84+
"or use `.expr` to translate it to Json for display purpose in Python SDK."
85+
)
8286

8387
def __int__(self):
8488
"""Override built-in Integer function for PipelineVariable"""

tests/integ/sagemaker/workflow/test_model_create_and_registration.py

+20-13
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,12 @@ def test_conditional_pytorch_training_model_registration(
9090
inputs = TrainingInput(s3_data=input_path)
9191

9292
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
93-
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
93+
instance_type = "ml.m5.xlarge"
9494
good_enough_input = ParameterInteger(name="GoodEnoughInput", default_value=1)
9595
in_condition_input = ParameterString(name="Foo", default_value="Foo")
9696

97+
# If image_uri is not provided, the instance_type should not be a pipeline variable
98+
# since instance_type is used to retrieve image_uri at compile time (PySDK)
9799
pytorch_estimator = PyTorch(
98100
entry_point=entry_point,
99101
role=role,
@@ -153,7 +155,6 @@ def test_conditional_pytorch_training_model_registration(
153155
in_condition_input,
154156
good_enough_input,
155157
instance_count,
156-
instance_type,
157158
],
158159
steps=[step_train, step_cond],
159160
sagemaker_session=sagemaker_session,
@@ -259,8 +260,10 @@ def test_sklearn_xgboost_sip_model_registration(
259260
prefix = "sip"
260261
bucket_name = sagemaker_session.default_bucket()
261262
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
262-
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
263+
instance_type = "ml.m5.xlarge"
263264

265+
# The instance_type should not be a pipeline variable
266+
# since it is used to retrieve image_uri at compile time (PySDK)
264267
sklearn_processor = SKLearnProcessor(
265268
role=role,
266269
instance_type=instance_type,
@@ -331,6 +334,8 @@ def test_sklearn_xgboost_sip_model_registration(
331334
source_dir = base_dir
332335
code_location = "s3://{0}/{1}/code".format(bucket_name, prefix)
333336

337+
# If image_uri is not provided, the instance_type should not be a pipeline variable
338+
# since instance_type is used to retrieve image_uri at compile time (PySDK)
334339
estimator = XGBoost(
335340
entry_point=entry_point,
336341
source_dir=source_dir,
@@ -416,7 +421,6 @@ def test_sklearn_xgboost_sip_model_registration(
416421
train_data_path_param,
417422
val_data_path_param,
418423
model_path_param,
419-
instance_type,
420424
instance_count,
421425
output_path_param,
422426
],
@@ -462,7 +466,7 @@ def test_model_registration_with_drift_check_baselines(
462466
pipeline_name,
463467
):
464468
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
465-
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
469+
instance_type = "ml.m5.xlarge"
466470

467471
# upload model data to s3
468472
model_local_path = os.path.join(DATA_DIR, "mxnet_mnist/model.tar.gz")
@@ -550,6 +554,9 @@ def test_model_registration_with_drift_check_baselines(
550554
),
551555
)
552556
customer_metadata_properties = {"key1": "value1"}
557+
558+
# If image_uri is not provided, the instance_type should not be a pipeline variable
559+
# since instance_type is used to retrieve image_uri at compile time (PySDK)
553560
estimator = XGBoost(
554561
entry_point="training.py",
555562
source_dir=os.path.join(DATA_DIR, "sip"),
@@ -579,7 +586,6 @@ def test_model_registration_with_drift_check_baselines(
579586
parameters=[
580587
model_uri_param,
581588
metrics_uri_param,
582-
instance_type,
583589
instance_count,
584590
],
585591
steps=[step_register],
@@ -667,9 +673,11 @@ def test_model_registration_with_model_repack(
667673
inputs = TrainingInput(s3_data=input_path)
668674

669675
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
670-
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
676+
instance_type = "ml.m5.xlarge"
671677
good_enough_input = ParameterInteger(name="GoodEnoughInput", default_value=1)
672678

679+
# If image_uri is not provided, the instance_type should not be a pipeline variable
680+
# since instance_type is used to retrieve image_uri at compile time (PySDK)
673681
pytorch_estimator = PyTorch(
674682
entry_point=entry_point,
675683
role=role,
@@ -724,7 +732,7 @@ def test_model_registration_with_model_repack(
724732

725733
pipeline = Pipeline(
726734
name=pipeline_name,
727-
parameters=[good_enough_input, instance_count, instance_type],
735+
parameters=[good_enough_input, instance_count],
728736
steps=[step_cond],
729737
sagemaker_session=sagemaker_session,
730738
)
@@ -767,8 +775,10 @@ def test_model_registration_with_tensorflow_model_with_pipeline_model(
767775
inputs = TrainingInput(s3_data=input_path)
768776

769777
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
770-
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
778+
instance_type = "ml.m5.xlarge"
771779

780+
# If image_uri is not provided, the instance_type should not be a pipeline variable
781+
# since instance_type is used to retrieve image_uri at compile time (PySDK)
772782
tensorflow_estimator = TensorFlow(
773783
entry_point=entry_point,
774784
role=role,
@@ -809,10 +819,7 @@ def test_model_registration_with_tensorflow_model_with_pipeline_model(
809819

810820
pipeline = Pipeline(
811821
name=pipeline_name,
812-
parameters=[
813-
instance_count,
814-
instance_type,
815-
],
822+
parameters=[instance_count],
816823
steps=[step_train, step_register_model],
817824
sagemaker_session=sagemaker_session,
818825
)

tests/integ/sagemaker/workflow/test_model_steps.py

+15-15
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,16 @@ def test_pytorch_training_model_registration_and_creation_without_custom_inferen
8383
inputs = TrainingInput(s3_data=input_path)
8484

8585
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
86-
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
8786

87+
# If image_uri is not provided, the instance_type should not be a pipeline variable
88+
# since instance_type is used to retrieve image_uri at compile time (PySDK)
8889
pytorch_estimator = PyTorch(
8990
entry_point=entry_point,
9091
role=role,
9192
framework_version="1.5.0",
9293
py_version="py3",
9394
instance_count=instance_count,
94-
instance_type=instance_type,
95+
instance_type="ml.m5.xlarge",
9596
sagemaker_session=pipeline_session,
9697
)
9798
train_step_args = pytorch_estimator.fit(inputs=inputs)
@@ -140,7 +141,7 @@ def test_pytorch_training_model_registration_and_creation_without_custom_inferen
140141
)
141142
pipeline = Pipeline(
142143
name=pipeline_name,
143-
parameters=[instance_count, instance_type],
144+
parameters=[instance_count],
144145
steps=[step_train, step_model_regis, step_model_create, step_fail],
145146
sagemaker_session=pipeline_session,
146147
)
@@ -203,15 +204,16 @@ def test_pytorch_training_model_registration_and_creation_with_custom_inference(
203204
inputs = TrainingInput(s3_data=input_path)
204205

205206
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
206-
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
207207

208+
# If image_uri is not provided, the instance_type should not be a pipeline variable
209+
# since instance_type is used to retrieve image_uri at compile time (PySDK)
208210
pytorch_estimator = PyTorch(
209211
entry_point=entry_point,
210212
role=role,
211213
framework_version="1.5.0",
212214
py_version="py3",
213215
instance_count=instance_count,
214-
instance_type=instance_type,
216+
instance_type="ml.m5.xlarge",
215217
sagemaker_session=pipeline_session,
216218
output_kms_key=kms_key,
217219
)
@@ -267,7 +269,7 @@ def test_pytorch_training_model_registration_and_creation_with_custom_inference(
267269
)
268270
pipeline = Pipeline(
269271
name=pipeline_name,
270-
parameters=[instance_count, instance_type],
272+
parameters=[instance_count],
271273
steps=[step_train, step_model_regis, step_model_create, step_fail],
272274
sagemaker_session=pipeline_session,
273275
)
@@ -400,7 +402,6 @@ def test_model_registration_with_drift_check_baselines_and_model_metrics(
400402
pipeline_name,
401403
):
402404
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
403-
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
404405

405406
# upload model data to s3
406407
model_local_path = os.path.join(DATA_DIR, "mxnet_mnist/model.tar.gz")
@@ -488,10 +489,12 @@ def test_model_registration_with_drift_check_baselines_and_model_metrics(
488489
),
489490
)
490491
customer_metadata_properties = {"key1": "value1"}
492+
# If image_uri is not provided, the instance_type should not be a pipeline variable
493+
# since instance_type is used to retrieve image_uri at compile time (PySDK)
491494
estimator = XGBoost(
492495
entry_point="training.py",
493496
source_dir=os.path.join(DATA_DIR, "sip"),
494-
instance_type=instance_type,
497+
instance_type="ml.m5.xlarge",
495498
instance_count=instance_count,
496499
framework_version="0.90-2",
497500
sagemaker_session=pipeline_session,
@@ -524,7 +527,6 @@ def test_model_registration_with_drift_check_baselines_and_model_metrics(
524527
parameters=[
525528
model_uri_param,
526529
metrics_uri_param,
527-
instance_type,
528530
instance_count,
529531
],
530532
steps=[step_model_register],
@@ -606,13 +608,14 @@ def test_model_registration_with_tensorflow_model_with_pipeline_model(
606608
)
607609
inputs = TrainingInput(s3_data=input_path)
608610
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
609-
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
610611

612+
# If image_uri is not provided, the instance_type should not be a pipeline variable
613+
# since instance_type is used to retrieve image_uri at compile time (PySDK)
611614
tensorflow_estimator = TensorFlow(
612615
entry_point=entry_point,
613616
role=role,
614617
instance_count=instance_count,
615-
instance_type=instance_type,
618+
instance_type="ml.m5.xlarge",
616619
framework_version=tf_full_version,
617620
py_version=tf_full_py_version,
618621
sagemaker_session=pipeline_session,
@@ -645,10 +648,7 @@ def test_model_registration_with_tensorflow_model_with_pipeline_model(
645648
)
646649
pipeline = Pipeline(
647650
name=pipeline_name,
648-
parameters=[
649-
instance_count,
650-
instance_type,
651-
],
651+
parameters=[instance_count],
652652
steps=[step_train, step_register_model],
653653
sagemaker_session=pipeline_session,
654654
)

tests/integ/sagemaker/workflow/test_retry.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,7 @@
3131
_REGISTER_MODEL_NAME_BASE,
3232
_CREATE_MODEL_NAME_BASE,
3333
)
34-
from sagemaker.workflow.parameters import (
35-
ParameterInteger,
36-
ParameterString,
37-
)
34+
from sagemaker.workflow.parameters import ParameterInteger
3835
from sagemaker.pytorch.estimator import PyTorch
3936
from sagemaker.workflow.pipeline import Pipeline
4037
from sagemaker.workflow.retry import (
@@ -185,9 +182,11 @@ def test_model_registration_with_model_repack(
185182
inputs = TrainingInput(s3_data=input_path)
186183

187184
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
188-
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
185+
instance_type = "ml.m5.xlarge"
189186
good_enough_input = ParameterInteger(name="GoodEnoughInput", default_value=1)
190187

188+
# If image_uri is not provided, the instance_type should not be a pipeline variable
189+
# since instance_type is used to retrieve image_uri at compile time (PySDK)
191190
pytorch_estimator = PyTorch(
192191
entry_point=entry_point,
193192
role=role,
@@ -252,7 +251,7 @@ def test_model_registration_with_model_repack(
252251
)
253252
pipeline = Pipeline(
254253
name=pipeline_name,
255-
parameters=[good_enough_input, instance_count, instance_type],
254+
parameters=[good_enough_input, instance_count],
256255
steps=[step_train, step_cond],
257256
sagemaker_session=pipeline_session,
258257
)

tests/integ/sagemaker/workflow/test_training_steps.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def test_training_job_with_debugger_and_profiler(
5959
pytorch_training_latest_py_version,
6060
):
6161
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
62-
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
62+
instance_type = "ml.m5.xlarge"
6363

6464
rules = [
6565
Rule.sagemaker(rule_configs.vanishing_gradient()),
@@ -78,6 +78,8 @@ def test_training_job_with_debugger_and_profiler(
7878
)
7979
inputs = TrainingInput(s3_data=input_path)
8080

81+
# If image_uri is not provided, the instance_type should not be a pipeline variable
82+
# since instance_type is used to retrieve image_uri at compile time (PySDK)
8183
pytorch_estimator = PyTorch(
8284
entry_point=script_path,
8385
role="SageMakerRole",
@@ -98,7 +100,7 @@ def test_training_job_with_debugger_and_profiler(
98100

99101
pipeline = Pipeline(
100102
name=pipeline_name,
101-
parameters=[instance_count, instance_type],
103+
parameters=[instance_count],
102104
steps=[step_train],
103105
sagemaker_session=sagemaker_session,
104106
)

tests/integ/sagemaker/workflow/test_tuning_steps.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,10 @@ def test_tuning_single_algo(
9191
inputs = TrainingInput(s3_data=input_path)
9292

9393
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
94-
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
94+
instance_type = "ml.m5.xlarge"
9595

96+
# If image_uri is not provided, the instance_type should not be a pipeline variable
97+
# since instance_type is used to retrieve image_uri at compile time (PySDK)
9698
pytorch_estimator = PyTorch(
9799
entry_point=entry_point,
98100
role=role,
@@ -167,7 +169,7 @@ def test_tuning_single_algo(
167169

168170
pipeline = Pipeline(
169171
name=pipeline_name,
170-
parameters=[instance_count, instance_type, min_batch_size, max_batch_size],
172+
parameters=[instance_count, min_batch_size, max_batch_size],
171173
steps=[step_tune, step_best_model, step_second_best_model],
172174
sagemaker_session=pipeline_session,
173175
)
@@ -221,10 +223,12 @@ def test_tuning_multi_algos(
221223
)
222224

223225
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
224-
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
226+
instance_type = "ml.m5.xlarge"
225227

226228
input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv"
227229

230+
# The instance_type should not be a pipeline variable
231+
# since it is used to retrieve image_uri at compile time (PySDK)
228232
sklearn_processor = SKLearnProcessor(
229233
framework_version="0.20.0",
230234
instance_type=instance_type,
@@ -259,6 +263,8 @@ def test_tuning_multi_algos(
259263
json_get_hp = JsonGet(
260264
step_name=step_process.name, property_file=property_file, json_path="train_size"
261265
)
266+
# If image_uri is not provided, the instance_type should not be a pipeline variable
267+
# since instance_type is used to retrieve image_uri at compile time (PySDK)
262268
pytorch_estimator = PyTorch(
263269
entry_point=entry_point,
264270
role=role,
@@ -307,7 +313,7 @@ def test_tuning_multi_algos(
307313

308314
pipeline = Pipeline(
309315
name=pipeline_name,
310-
parameters=[instance_count, instance_type, min_batch_size, max_batch_size],
316+
parameters=[instance_count, min_batch_size, max_batch_size, static_hp_1],
311317
steps=[step_process, step_tune],
312318
sagemaker_session=sagemaker_session,
313319
)

0 commit comments

Comments
 (0)