Skip to content

Commit bf6cc68

Browse files
committed
fix test prediction for python 3.0
1 parent 1da65a9 commit bf6cc68

File tree

2 files changed

+3
-20
lines changed

2 files changed

+3
-20
lines changed

tests/data/cifar_10/source/resnet_cifar_10.py

-4
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,3 @@ def _generate_synthetic_data(mode, batch_size):
142142

143143
def input_fn(serialized_data, content_type):
144144
return pickle.loads(serialized_data)
145-
146-
147-
def output_fn(data, accepts):
148-
return pickle.dumps(data)

tests/integ/test_tf_cifar.py

+3-16
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,7 @@ def __init__(self):
3535
self.content_type = PICKLE_CONTENT_TYPE
3636

3737
def __call__(self, data):
38-
return pickle.dumps(data)
39-
40-
41-
class PickleDeserializer(object):
42-
def __init__(self):
43-
self.accept = PICKLE_CONTENT_TYPE
44-
45-
def __call__(self, stream, content_type):
46-
try:
47-
data = stream.read().decode()
48-
return pickle.loads(data)
49-
finally:
50-
stream.close()
38+
return pickle.dumps(data, protocol=2)
5139

5240

5341
def test_cifar(sagemaker_session):
@@ -69,9 +57,8 @@ def test_cifar(sagemaker_session):
6957
with timeout_and_delete_endpoint(estimator=estimator, minutes=20):
7058
predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.p2.xlarge')
7159
predictor.serializer = PickleSerializer()
72-
predictor.deserializer = PickleDeserializer()
60+
predictor.content_type = PICKLE_CONTENT_TYPE
7361

7462
data = np.random.randn(32, 32, 3)
7563
predict_response = predictor.predict(data)
76-
77-
assert len(predict_response.outputs['probabilities'].float_val) == 10
64+
assert len(predict_response['outputs']['probabilities']['floatVal']) == 10

0 commit comments

Comments
 (0)