@@ -35,19 +35,7 @@ def __init__(self):
35
35
self .content_type = PICKLE_CONTENT_TYPE
36
36
37
37
def __call__ (self , data ):
38
- return pickle .dumps (data )
39
-
40
-
41
- class PickleDeserializer (object ):
42
- def __init__ (self ):
43
- self .accept = PICKLE_CONTENT_TYPE
44
-
45
- def __call__ (self , stream , content_type ):
46
- try :
47
- data = stream .read ().decode ()
48
- return pickle .loads (data )
49
- finally :
50
- stream .close ()
38
+ return pickle .dumps (data , protocol = 2 )
51
39
52
40
53
41
def test_cifar (sagemaker_session ):
@@ -69,9 +57,8 @@ def test_cifar(sagemaker_session):
69
57
with timeout_and_delete_endpoint (estimator = estimator , minutes = 20 ):
70
58
predictor = estimator .deploy (initial_instance_count = 1 , instance_type = 'ml.p2.xlarge' )
71
59
predictor .serializer = PickleSerializer ()
72
- predictor .deserializer = PickleDeserializer ()
60
+ predictor .content_type = PICKLE_CONTENT_TYPE
73
61
74
62
data = np .random .randn (32 , 32 , 3 )
75
63
predict_response = predictor .predict (data )
76
-
77
- assert len (predict_response .outputs ['probabilities' ].float_val ) == 10
64
+ assert len (predict_response ['outputs' ]['probabilities' ]['floatVal' ]) == 10
0 commit comments