Skip to content

Commit e394710

Browse files
authored
Add integration tests for basic training failure cases (aws#33)
1 parent b022f1a commit e394710

File tree

4 files changed

+45
-2
lines changed

4 files changed

+45
-2
lines changed

tests/data/iris/failure_script.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
def estimator_fn(run_config, params):
2+
"""For use with integration tests expecting failures."""
3+
raise Exception('This failure is expected.')
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
def train(**kwargs):
2+
"""For use with integration tests expecting failures."""
3+
raise Exception('This failure is expected.')

tests/integ/test_mxnet_train.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,20 @@ def test_deploy_model(mxnet_training_job, sagemaker_session):
7070

7171
data = numpy.zeros(shape=(1, 1, 28, 28))
7272
predictor.predict(data)
73+
74+
75+
def test_failed_training_job(sagemaker_session):
76+
with timeout(minutes=15):
77+
script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'failure_script.py')
78+
data_path = os.path.join(DATA_DIR, 'mxnet_mnist')
79+
80+
mx = MXNet(entry_point=script_path, role='SageMakerRole',
81+
train_instance_count=1, train_instance_type='ml.c4.xlarge',
82+
sagemaker_session=sagemaker_session)
83+
84+
train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
85+
key_prefix='integ-test-data/mxnet_mnist/train-failure')
86+
87+
with pytest.raises(ValueError) as e:
88+
mx.fit(train_input)
89+
assert 'This failure is expected' in str(e.value)

tests/integ/test_tf.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from tests.integ import DATA_DIR, REGION
2020
from tests.integ.timeout import timeout_and_delete_endpoint, timeout
2121

22+
DATA_PATH = os.path.join(DATA_DIR, 'iris', 'data')
23+
2224

2325
@pytest.fixture(scope='module')
2426
def sagemaker_session():
@@ -28,7 +30,6 @@ def sagemaker_session():
2830
def test_tf(sagemaker_session):
2931
with timeout(minutes=15):
3032
script_path = os.path.join(DATA_DIR, 'iris', 'iris-dnn-classifier.py')
31-
data_path = os.path.join(DATA_DIR, 'iris', 'data')
3233

3334
estimator = TensorFlow(entry_point=script_path,
3435
role='SageMakerRole',
@@ -40,7 +41,7 @@ def test_tf(sagemaker_session):
4041
sagemaker_session=sagemaker_session,
4142
base_job_name='test-tf')
4243

43-
inputs = estimator.sagemaker_session.upload_data(path=data_path, key_prefix='integ-test-data/tf_iris')
44+
inputs = estimator.sagemaker_session.upload_data(path=DATA_PATH, key_prefix='integ-test-data/tf_iris')
4445
estimator.fit(inputs)
4546
print('job succeeded: {}'.format(estimator.latest_training_job.name))
4647

@@ -49,3 +50,22 @@ def test_tf(sagemaker_session):
4950

5051
result = json_predictor.predict([6.4, 3.2, 4.5, 1.5])
5152
print('predict result: {}'.format(result))
53+
54+
55+
def test_failed_tf_training(sagemaker_session):
56+
with timeout(minutes=15):
57+
script_path = os.path.join(DATA_DIR, 'iris', 'failure_script.py')
58+
estimator = TensorFlow(entry_point=script_path,
59+
role='SageMakerRole',
60+
training_steps=1,
61+
evaluation_steps=1,
62+
hyperparameters={'input_tensor_name': 'inputs'},
63+
train_instance_count=1,
64+
train_instance_type='ml.c4.xlarge',
65+
sagemaker_session=sagemaker_session)
66+
67+
inputs = estimator.sagemaker_session.upload_data(path=DATA_PATH, key_prefix='integ-test-data/tf-failure')
68+
69+
with pytest.raises(ValueError) as e:
70+
estimator.fit(inputs)
71+
assert 'This failure is expected' in str(e.value)

0 commit comments

Comments
 (0)