Skip to content

Commit 2dfc046

Browse files
jbrockmendelJulianWgs
authored andcommitted
PERF: NDArrayBackedExtensionArray in cython (pandas-dev#40840)
1 parent eb1d23d commit 2dfc046

File tree

7 files changed

+254
-47
lines changed

7 files changed

+254
-47
lines changed

pandas/_libs/arrays.pyx

+167
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
"""
2+
Cython implementations for internal ExtensionArrays.
3+
"""
4+
cimport cython
5+
6+
import numpy as np
7+
8+
cimport numpy as cnp
9+
from numpy cimport ndarray
10+
11+
cnp.import_array()
12+
13+
14+
@cython.freelist(16)
15+
cdef class NDArrayBacked:
16+
"""
17+
Implementing these methods in cython improves performance quite a bit.
18+
19+
import pandas as pd
20+
21+
from pandas._libs.arrays import NDArrayBacked as cls
22+
23+
dti = pd.date_range("2016-01-01", periods=3)
24+
dta = dti._data
25+
arr = dta._ndarray
26+
27+
obj = cls._simple_new(arr, arr.dtype)
28+
29+
# for foo in [arr, dta, obj]: ...
30+
31+
%timeit foo.copy()
32+
299 ns ± 30 ns per loop # <-- arr underlying ndarray (for reference)
33+
530 ns ± 9.24 ns per loop # <-- dta with cython NDArrayBacked
34+
1.66 µs ± 46.3 ns per loop # <-- dta without cython NDArrayBacked
35+
328 ns ± 5.29 ns per loop # <-- obj with NDArrayBacked.__cinit__
36+
371 ns ± 6.97 ns per loop # <-- obj with NDArrayBacked._simple_new
37+
38+
%timeit foo.T
39+
125 ns ± 6.27 ns per loop # <-- arr underlying ndarray (for reference)
40+
226 ns ± 7.66 ns per loop # <-- dta with cython NDArrayBacked
41+
911 ns ± 16.6 ns per loop # <-- dta without cython NDArrayBacked
42+
215 ns ± 4.54 ns per loop # <-- obj with NDArrayBacked._simple_new
43+
44+
"""
45+
# TODO: implement take in terms of cnp.PyArray_TakeFrom
46+
# TODO: implement concat_same_type in terms of cnp.PyArray_Concatenate
47+
48+
cdef:
49+
readonly ndarray _ndarray
50+
readonly object _dtype
51+
52+
def __init__(self, ndarray values, object dtype):
53+
self._ndarray = values
54+
self._dtype = dtype
55+
56+
@classmethod
57+
def _simple_new(cls, ndarray values, object dtype):
58+
cdef:
59+
NDArrayBacked obj
60+
obj = NDArrayBacked.__new__(cls)
61+
obj._ndarray = values
62+
obj._dtype = dtype
63+
return obj
64+
65+
cpdef NDArrayBacked _from_backing_data(self, ndarray values):
66+
"""
67+
Construct a new ExtensionArray `new_array` with `arr` as its _ndarray.
68+
69+
This should round-trip:
70+
self == self._from_backing_data(self._ndarray)
71+
"""
72+
# TODO: re-reuse simple_new if/when it can be cpdef
73+
cdef:
74+
NDArrayBacked obj
75+
obj = NDArrayBacked.__new__(type(self))
76+
obj._ndarray = values
77+
obj._dtype = self._dtype
78+
return obj
79+
80+
cpdef __setstate__(self, state):
81+
if isinstance(state, dict):
82+
if "_data" in state:
83+
data = state.pop("_data")
84+
elif "_ndarray" in state:
85+
data = state.pop("_ndarray")
86+
else:
87+
raise ValueError
88+
self._ndarray = data
89+
self._dtype = state.pop("_dtype")
90+
91+
for key, val in state.items():
92+
setattr(self, key, val)
93+
elif isinstance(state, tuple):
94+
if len(state) != 3:
95+
if len(state) == 1 and isinstance(state[0], dict):
96+
self.__setstate__(state[0])
97+
return
98+
raise NotImplementedError(state)
99+
100+
data, dtype = state[:2]
101+
if isinstance(dtype, np.ndarray):
102+
dtype, data = data, dtype
103+
self._ndarray = data
104+
self._dtype = dtype
105+
106+
if isinstance(state[2], dict):
107+
for key, val in state[2].items():
108+
setattr(self, key, val)
109+
else:
110+
raise NotImplementedError(state)
111+
else:
112+
raise NotImplementedError(state)
113+
114+
def __len__(self) -> int:
115+
return len(self._ndarray)
116+
117+
@property
118+
def shape(self):
119+
# object cast bc _ndarray.shape is npy_intp*
120+
return (<object>(self._ndarray)).shape
121+
122+
@property
123+
def ndim(self) -> int:
124+
return self._ndarray.ndim
125+
126+
@property
127+
def size(self) -> int:
128+
return self._ndarray.size
129+
130+
@property
131+
def nbytes(self) -> int:
132+
return self._ndarray.nbytes
133+
134+
def copy(self):
135+
# NPY_ANYORDER -> same order as self._ndarray
136+
res_values = cnp.PyArray_NewCopy(self._ndarray, cnp.NPY_ANYORDER)
137+
return self._from_backing_data(res_values)
138+
139+
def delete(self, loc, axis=0):
140+
res_values = np.delete(self._ndarray, loc, axis=axis)
141+
return self._from_backing_data(res_values)
142+
143+
def swapaxes(self, axis1, axis2):
144+
res_values = cnp.PyArray_SwapAxes(self._ndarray, axis1, axis2)
145+
return self._from_backing_data(res_values)
146+
147+
# TODO: pass NPY_MAXDIMS equiv to axis=None?
148+
def repeat(self, repeats, axis: int = 0):
149+
if axis is None:
150+
axis = 0
151+
res_values = cnp.PyArray_Repeat(self._ndarray, repeats, <int>axis)
152+
return self._from_backing_data(res_values)
153+
154+
def reshape(self, *args, **kwargs):
155+
res_values = self._ndarray.reshape(*args, **kwargs)
156+
return self._from_backing_data(res_values)
157+
158+
def ravel(self, order="C"):
159+
# cnp.PyArray_OrderConverter(PyObject* obj, NPY_ORDER* order)
160+
# res_values = cnp.PyArray_Ravel(self._ndarray, order)
161+
res_values = self._ndarray.ravel(order)
162+
return self._from_backing_data(res_values)
163+
164+
@property
165+
def T(self):
166+
res_values = self._ndarray.T
167+
return self._from_backing_data(res_values)

pandas/compat/pickle_compat.py

+19
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,17 @@
1010
from typing import TYPE_CHECKING
1111
import warnings
1212

13+
import numpy as np
14+
15+
from pandas._libs.arrays import NDArrayBacked
1316
from pandas._libs.tslibs import BaseOffset
1417

1518
from pandas import Index
19+
from pandas.core.arrays import (
20+
DatetimeArray,
21+
PeriodArray,
22+
TimedeltaArray,
23+
)
1624

1725
if TYPE_CHECKING:
1826
from pandas import (
@@ -51,6 +59,10 @@ def load_reduce(self):
5159
cls = args[0]
5260
stack[-1] = cls.__new__(*args)
5361
return
62+
elif args and issubclass(args[0], PeriodArray):
63+
cls = args[0]
64+
stack[-1] = NDArrayBacked.__new__(*args)
65+
return
5466

5567
raise
5668

@@ -204,6 +216,13 @@ def load_newobj(self):
204216
# compat
205217
if issubclass(cls, Index):
206218
obj = object.__new__(cls)
219+
elif issubclass(cls, DatetimeArray) and not args:
220+
arr = np.array([], dtype="M8[ns]")
221+
obj = cls.__new__(cls, arr, arr.dtype)
222+
elif issubclass(cls, TimedeltaArray) and not args:
223+
arr = np.array([], dtype="m8[ns]")
224+
obj = cls.__new__(cls, arr, arr.dtype)
225+
207226
else:
208227
obj = cls.__new__(cls, *args)
209228

pandas/core/arrays/datetimelike.py

+7-31
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
algos,
2424
lib,
2525
)
26+
from pandas._libs.arrays import NDArrayBacked
2627
from pandas._libs.tslibs import (
2728
BaseOffset,
2829
IncompatibleFrequency,
@@ -141,7 +142,7 @@ class InvalidComparison(Exception):
141142
pass
142143

143144

144-
class DatetimeLikeArrayMixin(OpsMixin, NDArrayBackedExtensionArray):
145+
class DatetimeLikeArrayMixin(OpsMixin, NDArrayBacked, NDArrayBackedExtensionArray):
145146
"""
146147
Shared Base/Mixin class for DatetimeArray, TimedeltaArray, PeriodArray
147148
@@ -162,15 +163,6 @@ class DatetimeLikeArrayMixin(OpsMixin, NDArrayBackedExtensionArray):
162163
def __init__(self, data, dtype: Dtype | None = None, freq=None, copy=False):
163164
raise AbstractMethodError(self)
164165

165-
@classmethod
166-
def _simple_new(
167-
cls: type[DatetimeLikeArrayT],
168-
values: np.ndarray,
169-
freq: BaseOffset | None = None,
170-
dtype: Dtype | None = None,
171-
) -> DatetimeLikeArrayT:
172-
raise AbstractMethodError(cls)
173-
174166
@property
175167
def _scalar_type(self) -> type[DatetimeLikeScalar]:
176168
"""
@@ -254,31 +246,10 @@ def _check_compatible_with(
254246
# ------------------------------------------------------------------
255247
# NDArrayBackedExtensionArray compat
256248

257-
def __setstate__(self, state):
258-
if isinstance(state, dict):
259-
if "_data" in state and "_ndarray" not in state:
260-
# backward compat, changed what is property vs attribute
261-
state["_ndarray"] = state.pop("_data")
262-
for key, value in state.items():
263-
setattr(self, key, value)
264-
else:
265-
# PeriodArray, bc it mixes in a cython class
266-
if isinstance(state, tuple) and len(state) == 1:
267-
state = state[0]
268-
self.__setstate__(state)
269-
else:
270-
raise TypeError(state)
271-
272249
@cache_readonly
273250
def _data(self) -> np.ndarray:
274251
return self._ndarray
275252

276-
def _from_backing_data(
277-
self: DatetimeLikeArrayT, arr: np.ndarray
278-
) -> DatetimeLikeArrayT:
279-
# Note: we do not retain `freq`
280-
return type(self)._simple_new(arr, dtype=self.dtype)
281-
282253
# ------------------------------------------------------------------
283254

284255
def _box_func(self, x):
@@ -1718,6 +1689,11 @@ class TimelikeOps(DatetimeLikeArrayMixin):
17181689
Common ops for TimedeltaIndex/DatetimeIndex, but not PeriodIndex.
17191690
"""
17201691

1692+
def copy(self: TimelikeOps) -> TimelikeOps:
1693+
result = NDArrayBacked.copy(self)
1694+
result._freq = self._freq
1695+
return result
1696+
17211697
def _round(self, freq, mode, ambiguous, nonexistent):
17221698
# round the local times
17231699
if is_datetime64tz_dtype(self.dtype):

pandas/core/arrays/datetimes.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
lib,
2020
tslib,
2121
)
22+
from pandas._libs.arrays import NDArrayBacked
2223
from pandas._libs.tslibs import (
2324
BaseOffset,
2425
NaT,
@@ -313,8 +314,7 @@ def __init__(self, values, dtype=DT64NS_DTYPE, freq=None, copy: bool = False):
313314
# be incorrect(ish?) for the array as a whole
314315
dtype = DatetimeTZDtype(tz=timezones.tz_standardize(dtype.tz))
315316

316-
self._ndarray = values
317-
self._dtype = dtype
317+
NDArrayBacked.__init__(self, values=values, dtype=dtype)
318318
self._freq = freq
319319

320320
if inferred_freq is None and freq is not None:
@@ -327,10 +327,8 @@ def _simple_new(
327327
assert isinstance(values, np.ndarray)
328328
assert values.dtype == DT64NS_DTYPE
329329

330-
result = object.__new__(cls)
331-
result._ndarray = values
330+
result = super()._simple_new(values, dtype)
332331
result._freq = freq
333-
result._dtype = dtype
334332
return result
335333

336334
@classmethod

0 commit comments

Comments
 (0)