
Commit 69e2829

Dewen Qi (qidewenwhen) authored and committed
fix: Prevent passing PipelineVariable object into image_uris.retrieve
1 parent b4f06b3 commit 69e2829

12 files changed: 92 additions & 47 deletions

src/sagemaker/image_uris.py

Lines changed: 8 additions & 1 deletion
@@ -23,6 +23,7 @@
 from sagemaker.jumpstart.utils import is_jumpstart_model_input
 from sagemaker.spark import defaults
 from sagemaker.jumpstart import artifacts
+from sagemaker.workflow import is_pipeline_variable

 logger = logging.getLogger(__name__)

@@ -104,11 +105,17 @@ def retrieve(

     Raises:
         NotImplementedError: If the scope is not supported.
-        ValueError: If the combination of arguments specified is not supported.
+        ValueError: If the combination of arguments specified is not supported or
+            any PipelineVariable object is passed in.
         VulnerableJumpStartModelError: If any of the dependencies required by the script have
             known security vulnerabilities.
         DeprecatedJumpStartModelError: If the version of the model is deprecated.
     """
+    args = dict(locals())
+    for name, val in args.items():
+        if is_pipeline_variable(val):
+            raise ValueError("%s should not be a pipeline variable (%s)" % (name, type(val)))
+
     if is_jumpstart_model_input(model_id, model_version):
         return artifacts._retrieve_image_uri(
             model_id,
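
For reference, a minimal sketch (not part of this commit) of how the new guard surfaces to callers. The framework, region, and version arguments below are illustrative assumptions, not values taken from the diff:

from sagemaker import image_uris
from sagemaker.workflow.parameters import ParameterString

instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")

try:
    # Any PipelineVariable (here a ParameterString) is rejected up front, because
    # retrieve() needs concrete values while the pipeline definition is being built.
    image_uris.retrieve(
        framework="pytorch",          # assumed framework/version/py_version combination
        region="us-west-2",
        version="1.8.1",
        py_version="py36",
        instance_type=instance_type,  # pipeline variable -> ValueError
        image_scope="training",
    )
except ValueError as err:
    print(err)  # e.g. "instance_type should not be a pipeline variable (<class '...ParameterString'>)"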

src/sagemaker/workflow/entities.py

Lines changed: 5 additions & 1 deletion
@@ -78,7 +78,11 @@ def __add__(self, other: Union[Expression, PrimitiveType]):

     def __str__(self):
         """Override built-in String function for PipelineVariable"""
-        raise TypeError("Pipeline variables do not support __str__ operation.")
+        raise TypeError(
+            "Pipeline variables do not support __str__ operation. "
+            "Please use `.to_string()` to convert it to string type in execution time "
+            "or use `.expr` to translate it to Json for display purpose in Python SDK."
+        )

     def __int__(self):
         """Override built-in Integer function for PipelineVariable"""

tests/integ/sagemaker/workflow/test_model_registration.py

Lines changed: 20 additions & 13 deletions
@@ -84,10 +84,12 @@ def test_conditional_pytorch_training_model_registration(
     inputs = TrainingInput(s3_data=input_path)

     instance_count = ParameterInteger(name="InstanceCount", default_value=1)
-    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
+    instance_type = "ml.m5.xlarge"
     good_enough_input = ParameterInteger(name="GoodEnoughInput", default_value=1)
     in_condition_input = ParameterString(name="Foo", default_value="Foo")

+    # If image_uri is not provided, the instance_type should not be a pipeline variable
+    # since instance_type is used to retrieve image_uri in compile time (PySDK)
     pytorch_estimator = PyTorch(
         entry_point=entry_point,
         role=role,
@@ -146,7 +148,6 @@ def test_conditional_pytorch_training_model_registration(
             in_condition_input,
             good_enough_input,
             instance_count,
-            instance_type,
         ],
         steps=[step_cond],
         sagemaker_session=sagemaker_session,
@@ -252,8 +253,10 @@ def test_sklearn_xgboost_sip_model_registration(
     prefix = "sip"
     bucket_name = sagemaker_session.default_bucket()
     instance_count = ParameterInteger(name="InstanceCount", default_value=1)
-    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
+    instance_type = "ml.m5.xlarge"

+    # The instance_type should not be a pipeline variable
+    # since it is used to retrieve image_uri in compile time (PySDK)
     sklearn_processor = SKLearnProcessor(
         role=role,
         instance_type=instance_type,
@@ -324,6 +327,8 @@ def test_sklearn_xgboost_sip_model_registration(
     source_dir = base_dir
     code_location = "s3://{0}/{1}/code".format(bucket_name, prefix)

+    # If image_uri is not provided, the instance_type should not be a pipeline variable
+    # since instance_type is used to retrieve image_uri in compile time (PySDK)
     estimator = XGBoost(
         entry_point=entry_point,
         source_dir=source_dir,
@@ -409,7 +414,6 @@ def test_sklearn_xgboost_sip_model_registration(
             train_data_path_param,
             val_data_path_param,
             model_path_param,
-            instance_type,
             instance_count,
             output_path_param,
         ],
@@ -455,7 +459,7 @@ def test_model_registration_with_drift_check_baselines(
     pipeline_name,
 ):
     instance_count = ParameterInteger(name="InstanceCount", default_value=1)
-    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
+    instance_type = "ml.m5.xlarge"

     # upload model data to s3
     model_local_path = os.path.join(DATA_DIR, "mxnet_mnist/model.tar.gz")
@@ -543,6 +547,9 @@ def test_model_registration_with_drift_check_baselines(
         ),
     )
     customer_metadata_properties = {"key1": "value1"}
+
+    # If image_uri is not provided, the instance_type should not be a pipeline variable
+    # since instance_type is used to retrieve image_uri in compile time (PySDK)
     estimator = XGBoost(
         entry_point="training.py",
         source_dir=os.path.join(DATA_DIR, "sip"),
@@ -572,7 +579,6 @@ def test_model_registration_with_drift_check_baselines(
         parameters=[
             model_uri_param,
             metrics_uri_param,
-            instance_type,
             instance_count,
         ],
         steps=[step_register],
@@ -660,9 +666,11 @@ def test_model_registration_with_model_repack(
     inputs = TrainingInput(s3_data=input_path)

     instance_count = ParameterInteger(name="InstanceCount", default_value=1)
-    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
+    instance_type = "ml.m5.xlarge"
     good_enough_input = ParameterInteger(name="GoodEnoughInput", default_value=1)

+    # If image_uri is not provided, the instance_type should not be a pipeline variable
+    # since instance_type is used to retrieve image_uri in compile time (PySDK)
     pytorch_estimator = PyTorch(
         entry_point=entry_point,
         role=role,
@@ -717,7 +725,7 @@ def test_model_registration_with_model_repack(

     pipeline = Pipeline(
         name=pipeline_name,
-        parameters=[good_enough_input, instance_count, instance_type],
+        parameters=[good_enough_input, instance_count],
         steps=[step_cond],
         sagemaker_session=sagemaker_session,
     )
@@ -760,8 +768,10 @@ def test_model_registration_with_tensorflow_model_with_pipeline_model(
     inputs = TrainingInput(s3_data=input_path)

     instance_count = ParameterInteger(name="InstanceCount", default_value=1)
-    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
+    instance_type = "ml.m5.xlarge"

+    # If image_uri is not provided, the instance_type should not be a pipeline variable
+    # since instance_type is used to retrieve image_uri in compile time (PySDK)
     tensorflow_estimator = TensorFlow(
         entry_point=entry_point,
         role=role,
@@ -802,10 +812,7 @@ def test_model_registration_with_tensorflow_model_with_pipeline_model(

     pipeline = Pipeline(
         name=pipeline_name,
-        parameters=[
-            instance_count,
-            instance_type,
-        ],
+        parameters=[instance_count],
         steps=[step_train, step_register_model],
         sagemaker_session=sagemaker_session,
     )
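
The pattern these tests switch to, condensed into a standalone sketch; the entry point, role name, and framework versions are illustrative assumptions rather than values from the tests:

from sagemaker.pytorch import PyTorch
from sagemaker.workflow.parameters import ParameterInteger

# instance_count may stay a pipeline parameter, but instance_type must be a plain
# string whenever image_uri is omitted, because the SDK uses it to look up the
# image URI while the pipeline definition is being compiled.
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
instance_type = "ml.m5.xlarge"

pytorch_estimator = PyTorch(
    entry_point="mnist.py",         # illustrative training script
    role="SageMakerRole",           # illustrative IAM role name
    framework_version="1.8.1",      # assumed framework/py_version pair
    py_version="py36",
    instance_count=instance_count,  # pipeline variable is fine here
    instance_type=instance_type,    # must be concrete at compile time
)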

tests/integ/sagemaker/workflow/test_retry.py

Lines changed: 5 additions & 6 deletions
@@ -26,10 +26,7 @@
     DatasetDefinition,
     AthenaDatasetDefinition,
 )
-from sagemaker.workflow.parameters import (
-    ParameterInteger,
-    ParameterString,
-)
+from sagemaker.workflow.parameters import ParameterInteger
 from sagemaker.pytorch.estimator import PyTorch
 from sagemaker.workflow.pipeline import Pipeline
 from sagemaker.workflow.retry import (
@@ -183,9 +180,11 @@ def test_model_registration_with_model_repack(
     inputs = TrainingInput(s3_data=input_path)

     instance_count = ParameterInteger(name="InstanceCount", default_value=1)
-    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
+    instance_type = "ml.m5.xlarge"
     good_enough_input = ParameterInteger(name="GoodEnoughInput", default_value=1)

+    # If image_uri is not provided, the instance_type should not be a pipeline variable
+    # since instance_type is used to retrieve image_uri in compile time (PySDK)
     pytorch_estimator = PyTorch(
         entry_point=entry_point,
         role=role,
@@ -247,7 +246,7 @@

     pipeline = Pipeline(
         name=pipeline_name,
-        parameters=[good_enough_input, instance_count, instance_type],
+        parameters=[good_enough_input, instance_count],
         steps=[step_cond],
         sagemaker_session=sagemaker_session,
     )

tests/integ/sagemaker/workflow/test_training_steps.py

Lines changed: 4 additions & 2 deletions
@@ -59,7 +59,7 @@ def test_training_job_with_debugger_and_profiler(
     pytorch_training_latest_py_version,
 ):
     instance_count = ParameterInteger(name="InstanceCount", default_value=1)
-    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
+    instance_type = "ml.m5.xlarge"

     rules = [
         Rule.sagemaker(rule_configs.vanishing_gradient()),
@@ -78,6 +78,8 @@
     )
     inputs = TrainingInput(s3_data=input_path)

+    # If image_uri is not provided, the instance_type should not be a pipeline variable
+    # since instance_type is used to retrieve image_uri in compile time (PySDK)
     pytorch_estimator = PyTorch(
         entry_point=script_path,
         role="SageMakerRole",
@@ -98,7 +100,7 @@

     pipeline = Pipeline(
         name=pipeline_name,
-        parameters=[instance_count, instance_type],
+        parameters=[instance_count],
         steps=[step_train],
         sagemaker_session=sagemaker_session,
     )

tests/integ/sagemaker/workflow/test_tuning_steps.py

Lines changed: 10 additions & 4 deletions
@@ -93,8 +93,10 @@ def test_tuning_single_algo(
     inputs = TrainingInput(s3_data=input_path)

     instance_count = ParameterInteger(name="InstanceCount", default_value=1)
-    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
+    instance_type = "ml.m5.xlarge"

+    # If image_uri is not provided, the instance_type should not be a pipeline variable
+    # since instance_type is used to retrieve image_uri in compile time (PySDK)
     pytorch_estimator = PyTorch(
         entry_point=entry_point,
         role=role,
@@ -168,7 +170,7 @@

     pipeline = Pipeline(
         name=pipeline_name,
-        parameters=[instance_count, instance_type, min_batch_size, max_batch_size],
+        parameters=[instance_count, min_batch_size, max_batch_size],
         steps=[step_tune, step_best_model, step_second_best_model],
         sagemaker_session=sagemaker_session,
     )
@@ -225,10 +227,12 @@ def test_tuning_multi_algos(
     )

     instance_count = ParameterInteger(name="InstanceCount", default_value=1)
-    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
+    instance_type = "ml.m5.xlarge"

     input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv"

+    # The instance_type should not be a pipeline variable
+    # since it is used to retrieve image_uri in compile time (PySDK)
     sklearn_processor = SKLearnProcessor(
         framework_version="0.20.0",
         instance_type=instance_type,
@@ -263,6 +267,8 @@
     json_get_hp = JsonGet(
         step_name=step_process.name, property_file=property_file, json_path="train_size"
     )
+    # If image_uri is not provided, the instance_type should not be a pipeline variable
+    # since instance_type is used to retrieve image_uri in compile time (PySDK)
     pytorch_estimator = PyTorch(
         entry_point=entry_point,
         role=role,
@@ -311,7 +317,7 @@

     pipeline = Pipeline(
         name=pipeline_name,
-        parameters=[instance_count, instance_type, min_batch_size, max_batch_size],
+        parameters=[instance_count, min_batch_size, max_batch_size, static_hp_1],
         steps=[step_process, step_tune],
         sagemaker_session=sagemaker_session,
     )
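
Conversely, as an aside rather than something these tests exercise: when image_uri is supplied explicitly, no compile-time image lookup is needed, so instance_type can remain a ParameterString. The ECR URI below is a placeholder:

from sagemaker.pytorch import PyTorch
from sagemaker.workflow.parameters import ParameterString

instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")

pytorch_estimator = PyTorch(
    entry_point="train.py",       # illustrative training script
    role="SageMakerRole",         # illustrative IAM role name
    image_uri="123456789012.dkr.ecr.us-west-2.amazonaws.com/my-training-image:latest",  # placeholder
    instance_count=1,
    instance_type=instance_type,  # allowed here: image_uri is given directly, so it is not retrieved
)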

tests/integ/sagemaker/workflow/test_workflow.py

Lines changed: 18 additions & 15 deletions
@@ -157,12 +157,14 @@ def test_three_step_definition(
     athena_dataset_definition,
 ):
     framework_version = "0.20.0"
-    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
+    instance_type = "ml.m5.xlarge"
     instance_count = ParameterInteger(name="InstanceCount", default_value=1)
     output_prefix = ParameterString(name="OutputPrefix", default_value="output")

     input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv"

+    # The instance_type should not be a pipeline variable
+    # since it is used to retrieve image_uri in compile time (PySDK)
     sklearn_processor = SKLearnProcessor(
         framework_version=framework_version,
         instance_type=instance_type,
@@ -200,6 +202,8 @@
         code=os.path.join(script_dir, "preprocessing.py"),
     )

+    # If image_uri is not provided, the instance_type should not be a pipeline variable
+    # since instance_type is used to retrieve image_uri in compile time (PySDK)
     sklearn_train = SKLearn(
         framework_version=framework_version,
         entry_point=os.path.join(script_dir, "train.py"),
@@ -239,7 +243,7 @@

     pipeline = Pipeline(
         name=pipeline_name,
-        parameters=[instance_type, instance_count, output_prefix],
+        parameters=[instance_count, output_prefix],
         steps=[step_process, step_train, step_model],
         sagemaker_session=sagemaker_session,
     )
@@ -249,13 +253,6 @@

     assert set(tuple(param.items()) for param in definition["Parameters"]) == set(
         [
-            tuple(
-                {
-                    "Name": "InstanceType",
-                    "Type": "String",
-                    "DefaultValue": "ml.m5.xlarge",
-                }.items()
-            ),
             tuple({"Name": "InstanceCount", "Type": "Integer", "DefaultValue": 1}.items()),
             tuple(
                 {
@@ -300,14 +297,14 @@ def test_three_step_definition(
         ]
     )
     assert processing_args["ProcessingResources"]["ClusterConfig"] == {
-        "InstanceType": {"Get": "Parameters.InstanceType"},
+        "InstanceType": "ml.m5.xlarge",
         "InstanceCount": {"Get": "Parameters.InstanceCount"},
         "VolumeSizeInGB": 30,
     }

     assert training_args["ResourceConfig"] == {
         "InstanceCount": 1,
-        "InstanceType": {"Get": "Parameters.InstanceType"},
+        "InstanceType": "ml.m5.xlarge",
         "VolumeSizeInGB": 30,
     }
     assert training_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] == {
@@ -340,10 +337,12 @@ def test_steps_with_map_params_pipeline(
 ):
     instance_count = ParameterInteger(name="InstanceCount", default_value=2)
     framework_version = "0.20.0"
-    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
+    instance_type = "ml.m5.xlarge"
     output_prefix = ParameterString(name="OutputPrefix", default_value="output")
     input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv"

+    # The instance_type should not be a pipeline variable
+    # since it is used to retrieve image_uri in compile time (PySDK)
     sklearn_processor = SKLearnProcessor(
         framework_version=framework_version,
         instance_type=instance_type,
@@ -381,6 +380,8 @@
         code=os.path.join(script_dir, "preprocessing.py"),
     )

+    # If image_uri is not provided, the instance_type should not be a pipeline variable
+    # since instance_type is used to retrieve image_uri in compile time (PySDK)
     sklearn_train = SKLearn(
         framework_version=framework_version,
         entry_point=os.path.join(script_dir, "train.py"),
@@ -437,7 +438,7 @@

     pipeline = Pipeline(
         name=pipeline_name,
-        parameters=[instance_type, instance_count, output_prefix],
+        parameters=[instance_count, output_prefix],
         steps=[step_process, step_train, step_cond],
         sagemaker_session=sagemaker_session,
     )
@@ -1031,8 +1032,10 @@ def test_model_registration_with_tuning_model(
     inputs = TrainingInput(s3_data=input_path)

     instance_count = ParameterInteger(name="InstanceCount", default_value=1)
-    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
+    instance_type = "ml.m5.xlarge"

+    # If image_uri is not provided, the instance_type should not be a pipeline variable
+    # since instance_type is used to retrieve image_uri in compile time (PySDK)
     pytorch_estimator = PyTorch(
         entry_point=entry_point,
         role=role,
@@ -1083,7 +1086,7 @@

     pipeline = Pipeline(
         name=pipeline_name,
-        parameters=[instance_count, instance_type, min_batch_size, max_batch_size],
+        parameters=[instance_count, min_batch_size, max_batch_size],
         steps=[step_tune, step_register_best],
         sagemaker_session=sagemaker_session,
     )
