@@ -441,6 +441,44 @@ def test_tuning_mxnet(sagemaker_session, mxnet_full_version):
441
441
predictor .predict (data )
442
442
443
443
444
+ @pytest .mark .canary_quick
445
+ def test_tuning_tf_script_mode (sagemaker_session ):
446
+ resource_path = os .path .join (DATA_DIR , 'tensorflow_mnist' )
447
+ script_path = os .path .join (resource_path , 'mnist.py' )
448
+
449
+ estimator = TensorFlow (entry_point = script_path ,
450
+ role = 'SageMakerRole' ,
451
+ train_instance_count = 1 ,
452
+ train_instance_type = 'ml.m4.xlarge' ,
453
+ script_mode = True ,
454
+ sagemaker_session = sagemaker_session ,
455
+ py_version = PYTHON_VERSION ,
456
+ framework_version = TensorFlow .LATEST_VERSION )
457
+
458
+ hyperparameter_ranges = {'epochs' : IntegerParameter (1 , 2 )}
459
+ objective_metric_name = 'accuracy'
460
+ metric_definitions = [{'Name' : objective_metric_name , 'Regex' : 'accuracy = ([0-9\\ .]+)' }]
461
+
462
+ tuner = HyperparameterTuner (estimator ,
463
+ objective_metric_name ,
464
+ hyperparameter_ranges ,
465
+ metric_definitions ,
466
+ max_jobs = 2 ,
467
+ max_parallel_jobs = 2 )
468
+
469
+ with timeout (minutes = TUNING_DEFAULT_TIMEOUT_MINUTES ):
470
+ inputs = estimator .sagemaker_session .upload_data (path = os .path .join (resource_path , 'data' ),
471
+ key_prefix = 'scriptmode/mnist' )
472
+
473
+ tuning_job_name = unique_name_from_base ('tune-tf-script-mode' , max_length = 32 )
474
+ tuner .fit (inputs , job_name = tuning_job_name )
475
+
476
+ print ('Started hyperparameter tuning job with name: ' + tuning_job_name )
477
+
478
+ time .sleep (15 )
479
+ tuner .wait ()
480
+
481
+
444
482
@pytest .mark .canary_quick
445
483
@pytest .mark .skipif (PYTHON_VERSION != 'py2' , reason = "TensorFlow image supports only python 2." )
446
484
def test_tuning_tf (sagemaker_session ):
0 commit comments