Skip to content

Commit 047629f

Browse files
akrishna1995Ashwin Krishna
authored and
root
committed
feature: set default allow_pickle param to False (aws#4557)
* breaking: set default allow_pickle param to False * breaking: fix unit tests and linting NumpyDeserializer will not allow deserialization unless allow_pickle flag is set to True explicitly * fix: black-check --------- Co-authored-by: Ashwin Krishna <[email protected]>
1 parent 7b50900 commit 047629f

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

src/sagemaker/base_deserializers.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -196,14 +196,14 @@ class NumpyDeserializer(SimpleBaseDeserializer):
196196
single array.
197197
"""
198198

199-
def __init__(self, dtype=None, accept="application/x-npy", allow_pickle=True):
199+
def __init__(self, dtype=None, accept="application/x-npy", allow_pickle=False):
200200
"""Initialize a ``NumpyDeserializer`` instance.
201201
202202
Args:
203203
dtype (str): The dtype of the data (default: None).
204204
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
205205
is expected from the inference endpoint (default: "application/x-npy").
206-
allow_pickle (bool): Allow loading pickled object arrays (default: True).
206+
allow_pickle (bool): Allow loading pickled object arrays (default: False).
207207
"""
208208
super(NumpyDeserializer, self).__init__(accept=accept)
209209
self.dtype = dtype
@@ -227,10 +227,21 @@ def deserialize(self, stream, content_type):
227227
if content_type == "application/json":
228228
return np.array(json.load(codecs.getreader("utf-8")(stream)), dtype=self.dtype)
229229
if content_type == "application/x-npy":
230-
return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle)
230+
try:
231+
return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle)
232+
except ValueError as ve:
233+
raise ValueError(
234+
"Please set the param allow_pickle=True \
235+
to deserialize pickle objects in NumpyDeserializer"
236+
).with_traceback(ve.__traceback__)
231237
if content_type == "application/x-npz":
232238
try:
233239
return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle)
240+
except ValueError as ve:
241+
raise ValueError(
242+
"Please set the param allow_pickle=True \
243+
to deserialize pickle objectsin NumpyDeserializer"
244+
).with_traceback(ve.__traceback__)
234245
finally:
235246
stream.close()
236247
finally:

tests/unit/sagemaker/deserializers/test_deserializers.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,8 @@ def test_numpy_deserializer_from_npy(numpy_deserializer):
142142
assert np.array_equal(array, result)
143143

144144

145-
def test_numpy_deserializer_from_npy_object_array(numpy_deserializer):
145+
def test_numpy_deserializer_from_npy_object_array():
146+
numpy_deserializer = NumpyDeserializer(allow_pickle=True)
146147
array = np.array([{"a": "", "b": ""}, {"c": "", "d": ""}])
147148
stream = io.BytesIO()
148149
np.save(stream, array)

0 commit comments

Comments
 (0)