@@ -84,10 +84,12 @@ def test_conditional_pytorch_training_model_registration(
84
84
inputs = TrainingInput (s3_data = input_path )
85
85
86
86
instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
87
- instance_type = ParameterString ( name = "InstanceType" , default_value = " ml.m5.xlarge")
87
+ instance_type = " ml.m5.xlarge"
88
88
good_enough_input = ParameterInteger (name = "GoodEnoughInput" , default_value = 1 )
89
89
in_condition_input = ParameterString (name = "Foo" , default_value = "Foo" )
90
90
91
+ # If image_uri is not provided, the instance_type should not be a pipeline variable
92
+ # since instance_type is used to retrieve image_uri in compile time (PySDK)
91
93
pytorch_estimator = PyTorch (
92
94
entry_point = entry_point ,
93
95
role = role ,
@@ -146,7 +148,6 @@ def test_conditional_pytorch_training_model_registration(
146
148
in_condition_input ,
147
149
good_enough_input ,
148
150
instance_count ,
149
- instance_type ,
150
151
],
151
152
steps = [step_cond ],
152
153
sagemaker_session = sagemaker_session ,
@@ -252,8 +253,10 @@ def test_sklearn_xgboost_sip_model_registration(
252
253
prefix = "sip"
253
254
bucket_name = sagemaker_session .default_bucket ()
254
255
instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
255
- instance_type = ParameterString ( name = "InstanceType" , default_value = " ml.m5.xlarge")
256
+ instance_type = " ml.m5.xlarge"
256
257
258
+ # The instance_type should not be a pipeline variable
259
+ # since it is used to retrieve image_uri in compile time (PySDK)
257
260
sklearn_processor = SKLearnProcessor (
258
261
role = role ,
259
262
instance_type = instance_type ,
@@ -324,6 +327,8 @@ def test_sklearn_xgboost_sip_model_registration(
324
327
source_dir = base_dir
325
328
code_location = "s3://{0}/{1}/code" .format (bucket_name , prefix )
326
329
330
+ # If image_uri is not provided, the instance_type should not be a pipeline variable
331
+ # since instance_type is used to retrieve image_uri in compile time (PySDK)
327
332
estimator = XGBoost (
328
333
entry_point = entry_point ,
329
334
source_dir = source_dir ,
@@ -409,7 +414,6 @@ def test_sklearn_xgboost_sip_model_registration(
409
414
train_data_path_param ,
410
415
val_data_path_param ,
411
416
model_path_param ,
412
- instance_type ,
413
417
instance_count ,
414
418
output_path_param ,
415
419
],
@@ -455,7 +459,7 @@ def test_model_registration_with_drift_check_baselines(
455
459
pipeline_name ,
456
460
):
457
461
instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
458
- instance_type = ParameterString ( name = "InstanceType" , default_value = " ml.m5.xlarge")
462
+ instance_type = " ml.m5.xlarge"
459
463
460
464
# upload model data to s3
461
465
model_local_path = os .path .join (DATA_DIR , "mxnet_mnist/model.tar.gz" )
@@ -543,6 +547,9 @@ def test_model_registration_with_drift_check_baselines(
543
547
),
544
548
)
545
549
customer_metadata_properties = {"key1" : "value1" }
550
+
551
+ # If image_uri is not provided, the instance_type should not be a pipeline variable
552
+ # since instance_type is used to retrieve image_uri in compile time (PySDK)
546
553
estimator = XGBoost (
547
554
entry_point = "training.py" ,
548
555
source_dir = os .path .join (DATA_DIR , "sip" ),
@@ -572,7 +579,6 @@ def test_model_registration_with_drift_check_baselines(
572
579
parameters = [
573
580
model_uri_param ,
574
581
metrics_uri_param ,
575
- instance_type ,
576
582
instance_count ,
577
583
],
578
584
steps = [step_register ],
@@ -660,9 +666,11 @@ def test_model_registration_with_model_repack(
660
666
inputs = TrainingInput (s3_data = input_path )
661
667
662
668
instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
663
- instance_type = ParameterString ( name = "InstanceType" , default_value = " ml.m5.xlarge")
669
+ instance_type = " ml.m5.xlarge"
664
670
good_enough_input = ParameterInteger (name = "GoodEnoughInput" , default_value = 1 )
665
671
672
+ # If image_uri is not provided, the instance_type should not be a pipeline variable
673
+ # since instance_type is used to retrieve image_uri in compile time (PySDK)
666
674
pytorch_estimator = PyTorch (
667
675
entry_point = entry_point ,
668
676
role = role ,
@@ -717,7 +725,7 @@ def test_model_registration_with_model_repack(
717
725
718
726
pipeline = Pipeline (
719
727
name = pipeline_name ,
720
- parameters = [good_enough_input , instance_count , instance_type ],
728
+ parameters = [good_enough_input , instance_count ],
721
729
steps = [step_cond ],
722
730
sagemaker_session = sagemaker_session ,
723
731
)
0 commit comments