
Commit 73b0daa

Use npy as a default format for prediction instead of json. (aws#63)
1 parent 3a9dedd commit 73b0daa


2 files changed: +9 −9 lines changed


src/sagemaker/pytorch/model.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
 from sagemaker.fw_utils import create_image_uri
 from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
 from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION
-from sagemaker.predictor import RealTimePredictor, json_serializer, json_deserializer
+from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer
 from sagemaker.utils import name_from_image


@@ -34,7 +34,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
                Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
                using the default AWS configuration chain.
         """
-        super(PyTorchPredictor, self).__init__(endpoint_name, sagemaker_session, json_serializer, json_deserializer)
+        super(PyTorchPredictor, self).__init__(endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer)


 class PyTorchModel(FrameworkModel):
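
With this change, PyTorchPredictor defaults to the NPY serializer and NumPy deserializer, so requests are sent as NPY bytes and responses come back as numpy arrays rather than JSON. A minimal usage sketch against an already-deployed endpoint (the endpoint name and the 10-class MNIST-style output shape are assumptions for illustration):

import numpy
from sagemaker.pytorch import PyTorchPredictor

# Hypothetical endpoint name; use the name returned by estimator.deploy().
predictor = PyTorchPredictor('pytorch-mnist-endpoint')

# The input array is serialized to NPY on the way out; the response is
# deserialized back into a numpy array on the way in.
data = numpy.zeros(shape=(1, 1, 28, 28), dtype=numpy.float32)
output = predictor.predict(data)
print(output.shape)  # e.g. (1, 10) for a 10-class model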

tests/integ/test_pytorch_train.py

Lines changed: 7 additions & 7 deletions
@@ -43,14 +43,14 @@ def test_sync_fit_deploy(pytorch_training_job, sagemaker_session):
     with timeout(minutes=20):
         estimator = PyTorch.attach(pytorch_training_job, sagemaker_session=sagemaker_session)
         predictor = estimator.deploy(1, 'ml.c4.xlarge', endpoint_name=endpoint_name)
-        data = numpy.zeros(shape=(1, 1, 28, 28))
+        data = numpy.zeros(shape=(1, 1, 28, 28), dtype=numpy.float32)
         predictor.predict(data)

         batch_size = 100
-        data = numpy.random.rand(batch_size, 1, 28, 28)
+        data = numpy.random.rand(batch_size, 1, 28, 28).astype(numpy.float32)
         output = predictor.predict(data)

-        assert numpy.asarray(output).shape == (batch_size, 10)
+        assert output.shape == (batch_size, 10)


 def test_deploy_model(pytorch_training_job, sagemaker_session):
@@ -63,10 +63,10 @@ def test_deploy_model(pytorch_training_job, sagemaker_session):
         predictor = model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name)

         batch_size = 100
-        data = numpy.random.rand(batch_size, 1, 28, 28)
+        data = numpy.random.rand(batch_size, 1, 28, 28).astype(numpy.float32)
         output = predictor.predict(data)

-        assert numpy.asarray(output).shape == (batch_size, 10)
+        assert output.shape == (batch_size, 10)


 def test_async_fit_deploy(sagemaker_session, pytorch_full_version):
@@ -92,10 +92,10 @@ def test_async_fit_deploy(sagemaker_session, pytorch_full_version):
         predictor = estimator.deploy(1, instance_type, endpoint_name=endpoint_name)

         batch_size = 100
-        data = numpy.random.rand(batch_size, 1, 28, 28)
+        data = numpy.random.rand(batch_size, 1, 28, 28).astype(numpy.float32)
         output = predictor.predict(data)

-        assert numpy.asarray(output).shape == (batch_size, 10)
+        assert output.shape == (batch_size, 10)


 # TODO(nadiaya): Run against local mode when errors will be propagated
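
The test updates follow from the new defaults: NPY preserves dtype, so the inputs are cast to float32 to match what the model expects, and predict() now returns a numpy array directly, so the numpy.asarray(output) wrapper is no longer needed. A rough sketch of the serialize/deserialize round trip, assuming npy_serializer accepts an array and numpy_deserializer accepts a readable byte stream (the BytesIO stands in for the HTTP body):

import io

import numpy
from sagemaker.predictor import npy_serializer, numpy_deserializer

# Client side: the float32 batch is encoded as NPY bytes before being sent.
batch = numpy.random.rand(100, 1, 28, 28).astype(numpy.float32)
payload = npy_serializer(batch)

# Response side: NPY bytes decode back into a numpy array with shape and dtype
# intact, which is why the tests can assert on output.shape directly.
roundtripped = numpy_deserializer(io.BytesIO(payload))
assert roundtripped.shape == (100, 1, 28, 28)
assert roundtripped.dtype == numpy.float32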
