Skip to content

Commit 1844ce5

Browse files
committed
Add prediction output assertions.
1 parent 7ed5cc0 commit 1844ce5

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

tests/integ/test_pytorch_train.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ def test_sync_fit_deploy(pytorch_training_job, sagemaker_session):
4545
data = numpy.zeros(shape=(1, 1, 28, 28))
4646
predictor.predict(data)
4747

48+
batch_size = 100
49+
data = numpy.rand(shape=(100, 1, 28, 28))
50+
output = predictor.predict(data)
51+
52+
assert numpy.asarray(output).shape == (batch_size, 10)
53+
4854

4955
def test_deploy_model(pytorch_training_job, sagemaker_session):
5056
endpoint_name = 'test-pytorch-deploy-model-{}'.format(sagemaker_timestamp())
@@ -55,8 +61,11 @@ def test_deploy_model(pytorch_training_job, sagemaker_session):
5561
model = PyTorchModel(model_data, 'SageMakerRole', entry_point=MNIST_SCRIPT, sagemaker_session=sagemaker_session)
5662
predictor = model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name)
5763

58-
data = numpy.zeros(shape=(1, 1, 28, 28))
59-
predictor.predict(data)
64+
batch_size = 100
65+
data = numpy.rand(shape=(100, 1, 28, 28))
66+
output = predictor.predict(data)
67+
68+
assert numpy.asarray(output).shape == (batch_size, 10)
6069

6170

6271
def test_async_fit_deploy(sagemaker_session, pytorch_full_version):
@@ -80,8 +89,12 @@ def test_async_fit_deploy(sagemaker_session, pytorch_full_version):
8089
print("Re-attaching now to: %s" % training_job_name)
8190
estimator = PyTorch.attach(training_job_name=training_job_name, sagemaker_session=sagemaker_session)
8291
predictor = estimator.deploy(1, instance_type, endpoint_name=endpoint_name)
83-
data = numpy.zeros(shape=(1, 1, 28, 28))
84-
predictor.predict(data)
92+
93+
batch_size = 100
94+
data = numpy.rand(shape=(100, 1, 28, 28))
95+
output = predictor.predict(data)
96+
97+
assert numpy.asarray(output).shape == (batch_size, 10)
8598

8699

87100
# TODO(nadiaya): Run against local mode when errors will be propagated

0 commit comments

Comments
 (0)