@@ -1075,7 +1075,7 @@ def test_conditional_pytorch_training_model_registration(
1075
1075
pass
1076
1076
1077
1077
1078
- def test_tuning (
1078
+ def test_tuning_single_algo (
1079
1079
sagemaker_session ,
1080
1080
role ,
1081
1081
cpu_instance_type ,
@@ -1098,14 +1098,17 @@ def test_tuning(
1098
1098
role = role ,
1099
1099
framework_version = "1.5.0" ,
1100
1100
py_version = "py3" ,
1101
- instance_count = 1 ,
1102
- instance_type = "ml.m5.xlarge" ,
1101
+ instance_count = instance_count ,
1102
+ instance_type = instance_type ,
1103
1103
sagemaker_session = sagemaker_session ,
1104
1104
enable_sagemaker_metrics = True ,
1105
+ max_retry_attempts = 3 ,
1105
1106
)
1106
1107
1108
+ min_batch_size = ParameterString (name = "MinBatchSize" , default_value = "64" )
1109
+ max_batch_size = ParameterString (name = "MaxBatchSize" , default_value = "128" )
1107
1110
hyperparameter_ranges = {
1108
- "batch-size" : IntegerParameter (64 , 128 ),
1111
+ "batch-size" : IntegerParameter (min_batch_size , max_batch_size ),
1109
1112
}
1110
1113
1111
1114
tuner = HyperparameterTuner (
@@ -1161,7 +1164,7 @@ def test_tuning(
1161
1164
1162
1165
pipeline = Pipeline (
1163
1166
name = pipeline_name ,
1164
- parameters = [instance_count , instance_type ],
1167
+ parameters = [instance_count , instance_type , min_batch_size , max_batch_size ],
1165
1168
steps = [step_tune , step_best_model , step_second_best_model ],
1166
1169
sagemaker_session = sagemaker_session ,
1167
1170
)
@@ -1185,6 +1188,93 @@ def test_tuning(
1185
1188
pass
1186
1189
1187
1190
1191
+ def test_tuning_multi_algos (
1192
+ sagemaker_session ,
1193
+ role ,
1194
+ cpu_instance_type ,
1195
+ pipeline_name ,
1196
+ region_name ,
1197
+ ):
1198
+ base_dir = os .path .join (DATA_DIR , "pytorch_mnist" )
1199
+ entry_point = os .path .join (base_dir , "mnist.py" )
1200
+ input_path = sagemaker_session .upload_data (
1201
+ path = os .path .join (base_dir , "training" ),
1202
+ key_prefix = "integ-test-data/pytorch_mnist/training" ,
1203
+ )
1204
+
1205
+ instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
1206
+ instance_type = ParameterString (name = "InstanceType" , default_value = "ml.m5.xlarge" )
1207
+
1208
+ pytorch_estimator = PyTorch (
1209
+ entry_point = entry_point ,
1210
+ role = role ,
1211
+ framework_version = "1.5.0" ,
1212
+ py_version = "py3" ,
1213
+ instance_count = instance_count ,
1214
+ instance_type = instance_type ,
1215
+ sagemaker_session = sagemaker_session ,
1216
+ enable_sagemaker_metrics = True ,
1217
+ max_retry_attempts = 3 ,
1218
+ )
1219
+
1220
+ min_batch_size = ParameterString (name = "MinBatchSize" , default_value = "64" )
1221
+ max_batch_size = ParameterString (name = "MaxBatchSize" , default_value = "128" )
1222
+
1223
+ tuner = HyperparameterTuner .create (
1224
+ estimator_dict = {
1225
+ "estimator-1" : pytorch_estimator ,
1226
+ "estimator-2" : pytorch_estimator ,
1227
+ },
1228
+ objective_metric_name_dict = {
1229
+ "estimator-1" : "test:acc" ,
1230
+ "estimator-2" : "test:acc" ,
1231
+ },
1232
+ hyperparameter_ranges_dict = {
1233
+ "estimator-1" : {"batch-size" : IntegerParameter (min_batch_size , max_batch_size )},
1234
+ "estimator-2" : {"batch-size" : IntegerParameter (min_batch_size , max_batch_size )},
1235
+ },
1236
+ metric_definitions_dict = {
1237
+ "estimator-1" : [{"Name" : "test:acc" , "Regex" : "Overall test accuracy: (.*?);" }],
1238
+ "estimator-2" : [{"Name" : "test:acc" , "Regex" : "Overall test accuracy: (.*?);" }],
1239
+ },
1240
+ )
1241
+ inputs = {
1242
+ "estimator-1" : TrainingInput (s3_data = input_path ),
1243
+ "estimator-2" : TrainingInput (s3_data = input_path ),
1244
+ }
1245
+
1246
+ step_tune = TuningStep (
1247
+ name = "my-tuning-step" ,
1248
+ tuner = tuner ,
1249
+ inputs = inputs ,
1250
+ )
1251
+
1252
+ pipeline = Pipeline (
1253
+ name = pipeline_name ,
1254
+ parameters = [instance_count , instance_type , min_batch_size , max_batch_size ],
1255
+ steps = [step_tune ],
1256
+ sagemaker_session = sagemaker_session ,
1257
+ )
1258
+
1259
+ try :
1260
+ response = pipeline .create (role )
1261
+ create_arn = response ["PipelineArn" ]
1262
+ assert re .match (
1263
+ fr"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } " , create_arn
1264
+ )
1265
+
1266
+ execution = pipeline .start (parameters = {})
1267
+ assert re .match (
1268
+ fr"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } /execution/" ,
1269
+ execution .arn ,
1270
+ )
1271
+ finally :
1272
+ try :
1273
+ pipeline .delete ()
1274
+ except Exception :
1275
+ pass
1276
+
1277
+
1188
1278
def test_mxnet_model_registration (
1189
1279
sagemaker_session ,
1190
1280
role ,
0 commit comments