Skip to content

Commit 7646115

Browse files
committed
Fix typo,
1 parent 1844ce5 commit 7646115

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/integ/test_pytorch_train.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_sync_fit_deploy(pytorch_training_job, sagemaker_session):
4646
predictor.predict(data)
4747

4848
batch_size = 100
49-
data = numpy.rand(shape=(100, 1, 28, 28))
49+
data = numpy.random.rand(batch_size, 1, 28, 28)
5050
output = predictor.predict(data)
5151

5252
assert numpy.asarray(output).shape == (batch_size, 10)
@@ -62,7 +62,7 @@ def test_deploy_model(pytorch_training_job, sagemaker_session):
6262
predictor = model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name)
6363

6464
batch_size = 100
65-
data = numpy.rand(shape=(100, 1, 28, 28))
65+
data = numpy.random.rand(batch_size, 1, 28, 28)
6666
output = predictor.predict(data)
6767

6868
assert numpy.asarray(output).shape == (batch_size, 10)
@@ -91,7 +91,7 @@ def test_async_fit_deploy(sagemaker_session, pytorch_full_version):
9191
predictor = estimator.deploy(1, instance_type, endpoint_name=endpoint_name)
9292

9393
batch_size = 100
94-
data = numpy.rand(shape=(100, 1, 28, 28))
94+
data = numpy.random.rand(batch_size, 1, 28, 28)
9595
output = predictor.predict(data)
9696

9797
assert numpy.asarray(output).shape == (batch_size, 10)

0 commit comments

Comments
 (0)