Skip to content

Commit c6a449c

Browse files
authored
ENH: Allow NEP 42 dtypes to use np.save and np.load (numpy#24142)

Fixes numpy#24110.

First, this makes it so that, by default, NEP 42 dtypes cannot be pickled unless the dtype provides a pickle implementation. Currently numpy will pickle them but cannot unpickle them, because the type code written to disk is invalid. Raising an error is an improvement over writing corrupt files.

Second, if a dtype can be pickled, np.save now saves the array using pickle and records the dtype as object in the header (per @rkern's suggestion). A UserWarning is emitted when this happens. Unfortunately there is no way to indicate in the file that the array is not really an object array, so nothing can be done on the load side to detect this case; the UserWarning at save time will have to suffice. Adding such an indicator would require a revision of the npy file format, which this change deliberately avoids.

Last, this adds a pickle implementation to the scaled-float test dtype, plus a test performing a round-trip save and load with a scaled-float array.
1 parent 5047644 commit c6a449c

File tree

4 files changed

+65
-3
lines changed

4 files changed

+65
-3
lines changed

numpy/core/src/multiarray/descriptor.c

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2631,6 +2631,13 @@ arraydescr_reduce(PyArray_Descr *self, PyObject *NPY_UNUSED(args))
         obj = (PyObject *)self->typeobj;
         Py_INCREF(obj);
     }
+    else if (!NPY_DT_is_legacy(NPY_DTYPE(self))) {
+        PyErr_SetString(PyExc_RuntimeError,
+                "Custom dtypes cannot use the default pickle implementation "
+                "for NumPy dtypes. Add a custom pickle implementation to the "
+                "DType to avoid this error");
+        return NULL;
+    }
     else {
         elsize = self->elsize;
         if (self->type_num == NPY_UNICODE) {

numpy/core/src/umath/_scaled_float_dtype.c

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,13 +194,21 @@ sfloat_get_scaling(PyArray_SFloatDescr *self, PyObject *NPY_UNUSED(args))
194194
}
195195

196196

197+
static PyObject *
198+
sfloat___reduce__(PyArray_SFloatDescr *self)
199+
{
200+
return Py_BuildValue("(O(d))", Py_TYPE(self), self->scaling);
201+
}
202+
197203
PyMethodDef sfloat_methods[] = {
198204
{"scaled_by",
199205
(PyCFunction)python_sfloat_scaled_copy, METH_O,
200206
"Method to get a dtype copy with different scaling, mainly to "
201207
"avoid having to implement many ways to create new instances."},
202208
{"get_scaling",
203209
(PyCFunction)sfloat_get_scaling, METH_NOARGS, NULL},
210+
{"__reduce__",
211+
(PyCFunction)sfloat___reduce__, METH_NOARGS, NULL},
204212
{NULL, NULL, 0, NULL}
205213
};
206214

numpy/core/tests/test_custom_dtypes.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
+from tempfile import NamedTemporaryFile
+
 import pytest
 
 import numpy as np
@@ -243,6 +245,28 @@ def test_creation_class(self):
         assert np.zeros(3, dtype=SF).dtype == SF(1.)
         assert np.zeros_like(arr1, dtype=SF).dtype == SF(1.)
 
+    def test_np_save_load(self):
+        # this monkeypatch is needed because pickle
+        # uses the repr of a type to reconstruct it
+        np._ScaledFloatTestDType = SF
+
+        arr = np.array([1.0, 2.0, 3.0], dtype=SF(1.0))
+
+        # adapted from RoundtripTest.roundtrip in np.save tests
+        with NamedTemporaryFile("wb", delete=False, suffix=".npz") as f:
+            with pytest.warns(UserWarning) as record:
+                np.savez(f.name, arr)
+
+        assert len(record) == 1
+
+        with np.load(f.name, allow_pickle=True) as data:
+            larr = data["arr_0"]
+        assert_array_equal(arr.view(np.float64), larr.view(np.float64))
+        assert larr.dtype == arr.dtype == SF(1.0)
+
+        del np._ScaledFloatTestDType
+
 
 def test_type_pickle():
     # can't actually unpickle, but we can pickle (if in namespace)

numpy/lib/format.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,23 @@ def dtype_to_descr(dtype):
         # fiddled with. This needs to be fixed in the C implementation of
         # dtype().
         return dtype.descr
+    elif not type(dtype)._legacy:
+        # this must be a user-defined dtype since numpy does not yet expose any
+        # non-legacy dtypes in the public API
+        #
+        # non-legacy dtypes don't yet have __array_interface__
+        # support. Instead, as a hack, we use pickle to save the array, and lie
+        # that the dtype is object. When the array is loaded, the descriptor is
+        # unpickled with the array and the object dtype in the header is
+        # discarded.
+        #
+        # a future NEP should define a way to serialize user-defined
+        # descriptors and ideally work out the possible security implications
+        warnings.warn("Custom dtypes are saved as python objects using the "
+                      "pickle protocol. Loading this file requires "
+                      "allow_pickle=True to be set.",
+                      UserWarning, stacklevel=2)
+        return "|O"
     else:
         return dtype.str

@@ -710,12 +727,18 @@ def write_array(fp, array, version=None, allow_pickle=True, pickle_kwargs=None):
     # Set buffer size to 16 MiB to hide the Python loop overhead.
     buffersize = max(16 * 1024 ** 2 // array.itemsize, 1)
 
-    if array.dtype.hasobject:
+    dtype_class = type(array.dtype)
+
+    if array.dtype.hasobject or not dtype_class._legacy:
         # We contain Python objects so we cannot write out the data
         # directly. Instead, we will pickle it out
         if not allow_pickle:
-            raise ValueError("Object arrays cannot be saved when "
-                             "allow_pickle=False")
+            if array.dtype.hasobject:
+                raise ValueError("Object arrays cannot be saved when "
+                                 "allow_pickle=False")
+            if not dtype_class._legacy:
+                raise ValueError("User-defined dtypes cannot be saved "
+                                 "when allow_pickle=False")
         if pickle_kwargs is None:
             pickle_kwargs = {}
         pickle.dump(array, fp, protocol=3, **pickle_kwargs)

0 commit comments

Comments
 (0)