Skip to content

Commit d972439

Browse files
author
Dewen Qi
committed
fix: Prevent passing PipelineVariable object into image_uris.retrieve
1 parent 6ad4e45 commit d972439

12 files changed

+82
-31
lines changed

src/sagemaker/image_uris.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from sagemaker.jumpstart.utils import is_jumpstart_model_input
2424
from sagemaker.spark import defaults
2525
from sagemaker.jumpstart import artifacts
26+
from sagemaker.workflow import is_pipeline_variable
2627

2728
logger = logging.getLogger(__name__)
2829

@@ -104,11 +105,17 @@ def retrieve(
104105
105106
Raises:
106107
NotImplementedError: If the scope is not supported.
107-
ValueError: If the combination of arguments specified is not supported.
108+
ValueError: If the combination of arguments specified is not supported or
109+
any PipelineVariable object is passed in.
108110
VulnerableJumpStartModelError: If any of the dependencies required by the script have
109111
known security vulnerabilities.
110112
DeprecatedJumpStartModelError: If the version of the model is deprecated.
111113
"""
114+
args = dict(locals())
115+
for name, val in args.items():
116+
if is_pipeline_variable(val):
117+
raise ValueError("%s should not be a pipeline variable (%s)" % (name, type(val)))
118+
112119
if is_jumpstart_model_input(model_id, model_version):
113120
return artifacts._retrieve_image_uri(
114121
model_id,

src/sagemaker/workflow/entities.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,11 @@ def __add__(self, other: Union[Expression, PrimitiveType]):
7878

7979
def __str__(self):
8080
"""Override built-in String function for PipelineVariable"""
81-
raise TypeError("Pipeline variables do not support __str__ operation.")
81+
raise TypeError(
82+
"Pipeline variables do not support __str__ operation. "
83+
"Please use `.to_string()` to convert it to string type in execution time "
84+
"or use `.expr` to translate it to Json for display purpose in Python SDK."
85+
)
8286

8387
def __int__(self):
8488
"""Override built-in Integer function for PipelineVariable"""

tests/integ/sagemaker/workflow/test_model_registration.py

+16-8
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,12 @@ def test_conditional_pytorch_training_model_registration(
8484
inputs = TrainingInput(s3_data=input_path)
8585

8686
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
87-
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
87+
instance_type = "ml.m5.xlarge"
8888
good_enough_input = ParameterInteger(name="GoodEnoughInput", default_value=1)
8989
in_condition_input = ParameterString(name="Foo", default_value="Foo")
9090

91+
# If image_uri is not provided, the instance_type should not be a pipeline variable
92+
# since instance_type is used to retrieve image_uri at compile time (PySDK)
9193
pytorch_estimator = PyTorch(
9294
entry_point=entry_point,
9395
role=role,
@@ -146,7 +148,6 @@ def test_conditional_pytorch_training_model_registration(
146148
in_condition_input,
147149
good_enough_input,
148150
instance_count,
149-
instance_type,
150151
],
151152
steps=[step_cond],
152153
sagemaker_session=sagemaker_session,
@@ -252,8 +253,10 @@ def test_sklearn_xgboost_sip_model_registration(
252253
prefix = "sip"
253254
bucket_name = sagemaker_session.default_bucket()
254255
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
255-
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
256+
instance_type = "ml.m5.xlarge"
256257

258+
# The instance_type should not be a pipeline variable
259+
# since it is used to retrieve image_uri at compile time (PySDK)
257260
sklearn_processor = SKLearnProcessor(
258261
role=role,
259262
instance_type=instance_type,
@@ -324,6 +327,8 @@ def test_sklearn_xgboost_sip_model_registration(
324327
source_dir = base_dir
325328
code_location = "s3://{0}/{1}/code".format(bucket_name, prefix)
326329

330+
# If image_uri is not provided, the instance_type should not be a pipeline variable
331+
# since instance_type is used to retrieve image_uri at compile time (PySDK)
327332
estimator = XGBoost(
328333
entry_point=entry_point,
329334
source_dir=source_dir,
@@ -409,7 +414,6 @@ def test_sklearn_xgboost_sip_model_registration(
409414
train_data_path_param,
410415
val_data_path_param,
411416
model_path_param,
412-
instance_type,
413417
instance_count,
414418
output_path_param,
415419
],
@@ -455,7 +459,7 @@ def test_model_registration_with_drift_check_baselines(
455459
pipeline_name,
456460
):
457461
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
458-
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
462+
instance_type = "ml.m5.xlarge"
459463

460464
# upload model data to s3
461465
model_local_path = os.path.join(DATA_DIR, "mxnet_mnist/model.tar.gz")
@@ -543,6 +547,9 @@ def test_model_registration_with_drift_check_baselines(
543547
),
544548
)
545549
customer_metadata_properties = {"key1": "value1"}
550+
551+
# If image_uri is not provided, the instance_type should not be a pipeline variable
552+
# since instance_type is used to retrieve image_uri at compile time (PySDK)
546553
estimator = XGBoost(
547554
entry_point="training.py",
548555
source_dir=os.path.join(DATA_DIR, "sip"),
@@ -572,7 +579,6 @@ def test_model_registration_with_drift_check_baselines(
572579
parameters=[
573580
model_uri_param,
574581
metrics_uri_param,
575-
instance_type,
576582
instance_count,
577583
],
578584
steps=[step_register],
@@ -660,9 +666,11 @@ def test_model_registration_with_model_repack(
660666
inputs = TrainingInput(s3_data=input_path)
661667

662668
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
663-
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
669+
instance_type = "ml.m5.xlarge"
664670
good_enough_input = ParameterInteger(name="GoodEnoughInput", default_value=1)
665671

672+
# If image_uri is not provided, the instance_type should not be a pipeline variable
673+
# since instance_type is used to retrieve image_uri at compile time (PySDK)
666674
pytorch_estimator = PyTorch(
667675
entry_point=entry_point,
668676
role=role,
@@ -717,7 +725,7 @@ def test_model_registration_with_model_repack(
717725

718726
pipeline = Pipeline(
719727
name=pipeline_name,
720-
parameters=[good_enough_input, instance_count, instance_type],
728+
parameters=[good_enough_input, instance_count],
721729
steps=[step_cond],
722730
sagemaker_session=sagemaker_session,
723731
)

tests/integ/sagemaker/workflow/test_retry.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,7 @@
2626
DatasetDefinition,
2727
AthenaDatasetDefinition,
2828
)
29-
from sagemaker.workflow.parameters import (
30-
ParameterInteger,
31-
ParameterString,
32-
)
29+
from sagemaker.workflow.parameters import ParameterInteger
3330
from sagemaker.pytorch.estimator import PyTorch
3431
from sagemaker.workflow.pipeline import Pipeline
3532
from sagemaker.workflow.retry import (
@@ -183,9 +180,11 @@ def test_model_registration_with_model_repack(
183180
inputs = TrainingInput(s3_data=input_path)
184181

185182
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
186-
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
183+
instance_type = "ml.m5.xlarge"
187184
good_enough_input = ParameterInteger(name="GoodEnoughInput", default_value=1)
188185

186+
# If image_uri is not provided, the instance_type should not be a pipeline variable
187+
# since instance_type is used to retrieve image_uri at compile time (PySDK)
189188
pytorch_estimator = PyTorch(
190189
entry_point=entry_point,
191190
role=role,
@@ -247,7 +246,7 @@ def test_model_registration_with_model_repack(
247246

248247
pipeline = Pipeline(
249248
name=pipeline_name,
250-
parameters=[good_enough_input, instance_count, instance_type],
249+
parameters=[good_enough_input, instance_count],
251250
steps=[step_cond],
252251
sagemaker_session=sagemaker_session,
253252
)

tests/integ/sagemaker/workflow/test_training_steps.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def test_training_job_with_debugger_and_profiler(
5959
pytorch_training_latest_py_version,
6060
):
6161
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
62-
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
62+
instance_type = "ml.m5.xlarge"
6363

6464
rules = [
6565
Rule.sagemaker(rule_configs.vanishing_gradient()),
@@ -78,6 +78,8 @@ def test_training_job_with_debugger_and_profiler(
7878
)
7979
inputs = TrainingInput(s3_data=input_path)
8080

81+
# If image_uri is not provided, the instance_type should not be a pipeline variable
82+
# since instance_type is used to retrieve image_uri at compile time (PySDK)
8183
pytorch_estimator = PyTorch(
8284
entry_point=script_path,
8385
role="SageMakerRole",
@@ -98,7 +100,7 @@ def test_training_job_with_debugger_and_profiler(
98100

99101
pipeline = Pipeline(
100102
name=pipeline_name,
101-
parameters=[instance_count, instance_type],
103+
parameters=[instance_count],
102104
steps=[step_train],
103105
sagemaker_session=sagemaker_session,
104106
)

tests/integ/sagemaker/workflow/test_tuning_steps.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,10 @@ def test_tuning_single_algo(
9393
inputs = TrainingInput(s3_data=input_path)
9494

9595
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
96-
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
96+
instance_type = "ml.m5.xlarge"
9797

98+
# If image_uri is not provided, the instance_type should not be a pipeline variable
99+
# since instance_type is used to retrieve image_uri at compile time (PySDK)
98100
pytorch_estimator = PyTorch(
99101
entry_point=entry_point,
100102
role=role,
@@ -168,7 +170,7 @@ def test_tuning_single_algo(
168170

169171
pipeline = Pipeline(
170172
name=pipeline_name,
171-
parameters=[instance_count, instance_type, min_batch_size, max_batch_size],
173+
parameters=[instance_count, min_batch_size, max_batch_size],
172174
steps=[step_tune, step_best_model, step_second_best_model],
173175
sagemaker_session=sagemaker_session,
174176
)
@@ -225,10 +227,12 @@ def test_tuning_multi_algos(
225227
)
226228

227229
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
228-
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
230+
instance_type = "ml.m5.xlarge"
229231

230232
input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv"
231233

234+
# The instance_type should not be a pipeline variable
235+
# since it is used to retrieve image_uri at compile time (PySDK)
232236
sklearn_processor = SKLearnProcessor(
233237
framework_version="0.20.0",
234238
instance_type=instance_type,
@@ -263,6 +267,8 @@ def test_tuning_multi_algos(
263267
json_get_hp = JsonGet(
264268
step_name=step_process.name, property_file=property_file, json_path="train_size"
265269
)
270+
# If image_uri is not provided, the instance_type should not be a pipeline variable
271+
# since instance_type is used to retrieve image_uri at compile time (PySDK)
266272
pytorch_estimator = PyTorch(
267273
entry_point=entry_point,
268274
role=role,
@@ -311,7 +317,7 @@ def test_tuning_multi_algos(
311317

312318
pipeline = Pipeline(
313319
name=pipeline_name,
314-
parameters=[instance_count, instance_type, min_batch_size, max_batch_size],
320+
parameters=[instance_count, min_batch_size, max_batch_size],
315321
steps=[step_process, step_tune],
316322
sagemaker_session=sagemaker_session,
317323
)

tests/integ/sagemaker/workflow/test_workflow.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -157,12 +157,14 @@ def test_three_step_definition(
157157
athena_dataset_definition,
158158
):
159159
framework_version = "0.20.0"
160-
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
160+
instance_type = "ml.m5.xlarge"
161161
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
162162
output_prefix = ParameterString(name="OutputPrefix", default_value="output")
163163

164164
input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv"
165165

166+
# The instance_type should not be a pipeline variable
167+
# since it is used to retrieve image_uri at compile time (PySDK)
166168
sklearn_processor = SKLearnProcessor(
167169
framework_version=framework_version,
168170
instance_type=instance_type,
@@ -200,6 +202,8 @@ def test_three_step_definition(
200202
code=os.path.join(script_dir, "preprocessing.py"),
201203
)
202204

205+
# If image_uri is not provided, the instance_type should not be a pipeline variable
206+
# since instance_type is used to retrieve image_uri at compile time (PySDK)
203207
sklearn_train = SKLearn(
204208
framework_version=framework_version,
205209
entry_point=os.path.join(script_dir, "train.py"),
@@ -239,7 +243,7 @@ def test_three_step_definition(
239243

240244
pipeline = Pipeline(
241245
name=pipeline_name,
242-
parameters=[instance_type, instance_count, output_prefix],
246+
parameters=[instance_count, output_prefix],
243247
steps=[step_process, step_train, step_model],
244248
sagemaker_session=sagemaker_session,
245249
)
@@ -340,10 +344,12 @@ def test_steps_with_map_params_pipeline(
340344
):
341345
instance_count = ParameterInteger(name="InstanceCount", default_value=2)
342346
framework_version = "0.20.0"
343-
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
347+
instance_type = "ml.m5.xlarge"
344348
output_prefix = ParameterString(name="OutputPrefix", default_value="output")
345349
input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv"
346350

351+
# The instance_type should not be a pipeline variable
352+
# since it is used to retrieve image_uri at compile time (PySDK)
347353
sklearn_processor = SKLearnProcessor(
348354
framework_version=framework_version,
349355
instance_type=instance_type,
@@ -381,6 +387,8 @@ def test_steps_with_map_params_pipeline(
381387
code=os.path.join(script_dir, "preprocessing.py"),
382388
)
383389

390+
# If image_uri is not provided, the instance_type should not be a pipeline variable
391+
# since instance_type is used to retrieve image_uri at compile time (PySDK)
384392
sklearn_train = SKLearn(
385393
framework_version=framework_version,
386394
entry_point=os.path.join(script_dir, "train.py"),
@@ -437,7 +445,7 @@ def test_steps_with_map_params_pipeline(
437445

438446
pipeline = Pipeline(
439447
name=pipeline_name,
440-
parameters=[instance_type, instance_count, output_prefix],
448+
parameters=[instance_count, output_prefix],
441449
steps=[step_process, step_train, step_cond],
442450
sagemaker_session=sagemaker_session,
443451
)

tests/unit/sagemaker/image_uris/test_retrieve.py

+17
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from mock import patch
2020

2121
from sagemaker import image_uris
22+
from sagemaker.workflow.parameters import ParameterString
2223

2324
BASE_CONFIG = {
2425
"processors": ["cpu", "gpu"],
@@ -717,3 +718,19 @@ def test_retrieve_huggingface(config_for_framework):
717718
"564829616587.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:"
718719
"1.6.0-transformers4.3.1-gpu-py37-cu110-ubuntu18.04" == pt_new_version
719720
)
721+
722+
723+
def test_retrieve_with_pipeline_variable():
724+
with pytest.raises(Exception) as error:
725+
image_uris.retrieve(
726+
framework="tensorflow",
727+
version="1.15",
728+
py_version="py3",
729+
instance_type=ParameterString(
730+
name="TrainingInstanceType",
731+
default_value="ml.m5.xlarge",
732+
),
733+
region="us-east-1",
734+
image_scope="training",
735+
)
736+
assert "instance_type should not be a pipeline variable" in str(error.value)

tests/unit/sagemaker/workflow/test_execution_variables.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_implicit_value():
3333

3434
with pytest.raises(TypeError) as error:
3535
str(var)
36-
assert str(error.value) == "Pipeline variables do not support __str__ operation."
36+
assert "Pipeline variables do not support __str__ operation." in str(error.value)
3737

3838
with pytest.raises(TypeError) as error:
3939
int(var)

tests/unit/sagemaker/workflow/test_functions.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def test_implicit_value_on_join():
8181

8282
with pytest.raises(TypeError) as error:
8383
str(func)
84-
assert str(error.value) == "Pipeline variables do not support __str__ operation."
84+
assert "Pipeline variables do not support __str__ operation." in str(error.value)
8585

8686
with pytest.raises(TypeError) as error:
8787
int(func)
@@ -189,7 +189,7 @@ def test_implicit_value_on_json_get():
189189

190190
with pytest.raises(TypeError) as error:
191191
str(func)
192-
assert str(error.value) == "Pipeline variables do not support __str__ operation."
192+
assert "Pipeline variables do not support __str__ operation." in str(error.value)
193193

194194
with pytest.raises(TypeError) as error:
195195
int(func)

tests/unit/sagemaker/workflow/test_parameters.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def test_parameter_to_string_and_string_implicit_value():
7676
with pytest.raises(TypeError) as error:
7777
str(param)
7878

79-
assert str(error.value) == "Pipeline variables do not support __str__ operation."
79+
assert "Pipeline variables do not support __str__ operation." in str(error.value)
8080

8181

8282
def test_parameter_integer_implicit_value():

tests/unit/sagemaker/workflow/test_properties.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def test_implicit_value():
111111

112112
with pytest.raises(TypeError) as error:
113113
str(prop.CreationTime)
114-
assert str(error.value) == "Pipeline variables do not support __str__ operation."
114+
assert "Pipeline variables do not support __str__ operation." in str(error.value)
115115

116116
with pytest.raises(TypeError) as error:
117117
int(prop.CreationTime)

0 commit comments

Comments
 (0)