@@ -83,16 +83,15 @@ def test_pytorch_training_model_registration_and_creation_without_custom_inferen
83
83
inputs = TrainingInput (s3_data = input_path )
84
84
85
85
instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
86
+ instance_type = ParameterString (name = "InstanceType" , default_value = "ml.m5.xlarge" )
86
87
87
- # If image_uri is not provided, the instance_type should not be a pipeline variable
88
- # since instance_type is used to retrieve image_uri in compile time (PySDK)
89
88
pytorch_estimator = PyTorch (
90
89
entry_point = entry_point ,
91
90
role = role ,
92
91
framework_version = "1.5.0" ,
93
92
py_version = "py3" ,
94
93
instance_count = instance_count ,
95
- instance_type = "ml.m5.xlarge" ,
94
+ instance_type = instance_type ,
96
95
sagemaker_session = pipeline_session ,
97
96
)
98
97
train_step_args = pytorch_estimator .fit (inputs = inputs )
@@ -141,7 +140,7 @@ def test_pytorch_training_model_registration_and_creation_without_custom_inferen
141
140
)
142
141
pipeline = Pipeline (
143
142
name = pipeline_name ,
144
- parameters = [instance_count ],
143
+ parameters = [instance_count , instance_type ],
145
144
steps = [step_train , step_model_regis , step_model_create , step_fail ],
146
145
sagemaker_session = pipeline_session ,
147
146
)
@@ -204,16 +203,15 @@ def test_pytorch_training_model_registration_and_creation_with_custom_inference(
204
203
inputs = TrainingInput (s3_data = input_path )
205
204
206
205
instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
206
+ instance_type = ParameterString (name = "InstanceType" , default_value = "ml.m5.xlarge" )
207
207
208
- # If image_uri is not provided, the instance_type should not be a pipeline variable
209
- # since instance_type is used to retrieve image_uri in compile time (PySDK)
210
208
pytorch_estimator = PyTorch (
211
209
entry_point = entry_point ,
212
210
role = role ,
213
211
framework_version = "1.5.0" ,
214
212
py_version = "py3" ,
215
213
instance_count = instance_count ,
216
- instance_type = "ml.m5.xlarge" ,
214
+ instance_type = instance_type ,
217
215
sagemaker_session = pipeline_session ,
218
216
output_kms_key = kms_key ,
219
217
)
@@ -269,7 +267,7 @@ def test_pytorch_training_model_registration_and_creation_with_custom_inference(
269
267
)
270
268
pipeline = Pipeline (
271
269
name = pipeline_name ,
272
- parameters = [instance_count ],
270
+ parameters = [instance_count , instance_type ],
273
271
steps = [step_train , step_model_regis , step_model_create , step_fail ],
274
272
sagemaker_session = pipeline_session ,
275
273
)
@@ -402,6 +400,7 @@ def test_model_registration_with_drift_check_baselines_and_model_metrics(
402
400
pipeline_name ,
403
401
):
404
402
instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
403
+ instance_type = ParameterString (name = "InstanceType" , default_value = "ml.m5.xlarge" )
405
404
406
405
# upload model data to s3
407
406
model_local_path = os .path .join (DATA_DIR , "mxnet_mnist/model.tar.gz" )
@@ -489,12 +488,10 @@ def test_model_registration_with_drift_check_baselines_and_model_metrics(
489
488
),
490
489
)
491
490
customer_metadata_properties = {"key1" : "value1" }
492
- # If image_uri is not provided, the instance_type should not be a pipeline variable
493
- # since instance_type is used to retrieve image_uri in compile time (PySDK)
494
491
estimator = XGBoost (
495
492
entry_point = "training.py" ,
496
493
source_dir = os .path .join (DATA_DIR , "sip" ),
497
- instance_type = "ml.m5.xlarge" ,
494
+ instance_type = instance_type ,
498
495
instance_count = instance_count ,
499
496
framework_version = "0.90-2" ,
500
497
sagemaker_session = pipeline_session ,
@@ -527,6 +524,7 @@ def test_model_registration_with_drift_check_baselines_and_model_metrics(
527
524
parameters = [
528
525
model_uri_param ,
529
526
metrics_uri_param ,
527
+ instance_type ,
530
528
instance_count ,
531
529
],
532
530
steps = [step_model_register ],
@@ -608,14 +606,13 @@ def test_model_registration_with_tensorflow_model_with_pipeline_model(
608
606
)
609
607
inputs = TrainingInput (s3_data = input_path )
610
608
instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
609
+ instance_type = ParameterString (name = "InstanceType" , default_value = "ml.m5.xlarge" )
611
610
612
- # If image_uri is not provided, the instance_type should not be a pipeline variable
613
- # since instance_type is used to retrieve image_uri in compile time (PySDK)
614
611
tensorflow_estimator = TensorFlow (
615
612
entry_point = entry_point ,
616
613
role = role ,
617
614
instance_count = instance_count ,
618
- instance_type = "ml.m5.xlarge" ,
615
+ instance_type = instance_type ,
619
616
framework_version = tf_full_version ,
620
617
py_version = tf_full_py_version ,
621
618
sagemaker_session = pipeline_session ,
@@ -648,7 +645,10 @@ def test_model_registration_with_tensorflow_model_with_pipeline_model(
648
645
)
649
646
pipeline = Pipeline (
650
647
name = pipeline_name ,
651
- parameters = [instance_count ],
648
+ parameters = [
649
+ instance_count ,
650
+ instance_type ,
651
+ ],
652
652
steps = [step_train , step_register_model ],
653
653
sagemaker_session = pipeline_session ,
654
654
)
0 commit comments