19
19
from tests .integ import DATA_DIR , REGION
20
20
from tests .integ .timeout import timeout_and_delete_endpoint , timeout
21
21
22
+ DATA_PATH = os .path .join (DATA_DIR , 'iris' , 'data' )
23
+
22
24
23
25
@pytest .fixture (scope = 'module' )
24
26
def sagemaker_session ():
@@ -28,7 +30,6 @@ def sagemaker_session():
28
30
def test_tf (sagemaker_session ):
29
31
with timeout (minutes = 15 ):
30
32
script_path = os .path .join (DATA_DIR , 'iris' , 'iris-dnn-classifier.py' )
31
- data_path = os .path .join (DATA_DIR , 'iris' , 'data' )
32
33
33
34
estimator = TensorFlow (entry_point = script_path ,
34
35
role = 'SageMakerRole' ,
@@ -40,7 +41,7 @@ def test_tf(sagemaker_session):
40
41
sagemaker_session = sagemaker_session ,
41
42
base_job_name = 'test-tf' )
42
43
43
- inputs = estimator .sagemaker_session .upload_data (path = data_path , key_prefix = 'integ-test-data/tf_iris' )
44
+ inputs = estimator .sagemaker_session .upload_data (path = DATA_PATH , key_prefix = 'integ-test-data/tf_iris' )
44
45
estimator .fit (inputs )
45
46
print ('job succeeded: {}' .format (estimator .latest_training_job .name ))
46
47
@@ -49,3 +50,22 @@ def test_tf(sagemaker_session):
49
50
50
51
result = json_predictor .predict ([6.4 , 3.2 , 4.5 , 1.5 ])
51
52
print ('predict result: {}' .format (result ))
53
+
54
+
55
+ def test_failed_tf_training (sagemaker_session ):
56
+ with timeout (minutes = 15 ):
57
+ script_path = os .path .join (DATA_DIR , 'iris' , 'failure_script.py' )
58
+ estimator = TensorFlow (entry_point = script_path ,
59
+ role = 'SageMakerRole' ,
60
+ training_steps = 1 ,
61
+ evaluation_steps = 1 ,
62
+ hyperparameters = {'input_tensor_name' : 'inputs' },
63
+ train_instance_count = 1 ,
64
+ train_instance_type = 'ml.c4.xlarge' ,
65
+ sagemaker_session = sagemaker_session )
66
+
67
+ inputs = estimator .sagemaker_session .upload_data (path = DATA_PATH , key_prefix = 'integ-test-data/tf-failure' )
68
+
69
+ with pytest .raises (ValueError ) as e :
70
+ estimator .fit (inputs )
71
+ assert 'This failure is expected' in str (e .value )
0 commit comments