Skip to content

Commit c4ebf21

Browse files
authored
REF: Implement NDArrayBackedExtensionArray (#33660)
1 parent c608824 commit c4ebf21

File tree

4 files changed

+102
-24
lines changed

4 files changed

+102
-24
lines changed

doc/source/whatsnew/v1.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ Backwards incompatible API changes
191191
Previously a ``UnsupportedFunctionCall`` was raised (``AssertionError`` if ``min_count`` passed into :meth:`~DataFrameGroupby.median`) (:issue:`31485`)
192192
- :meth:`DataFrame.at` and :meth:`Series.at` will raise a ``TypeError`` instead of a ``ValueError`` if an incompatible key is passed, and ``KeyError`` if a missing key is passed, matching the behavior of ``.loc[]`` (:issue:`31722`)
193193
- Passing an integer dtype other than ``int64`` to ``np.array(period_index, dtype=...)`` will now raise ``TypeError`` instead of incorrectly using ``int64`` (:issue:`32255`)
194+
- Passing an invalid ``fill_value`` to :meth:`Categorical.take` raises a ``ValueError`` instead of ``TypeError`` (:issue:`33660`)
194195

195196
``MultiIndex.get_indexer`` interprets `method` argument differently
196197
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

pandas/core/arrays/_mixins.py

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from typing import Any, Sequence, TypeVar
2+
3+
import numpy as np
4+
5+
from pandas.errors import AbstractMethodError
6+
7+
from pandas.core.algorithms import take
8+
from pandas.core.arrays.base import ExtensionArray
9+
10+
_T = TypeVar("_T", bound="NDArrayBackedExtensionArray")
11+
12+
13+
class NDArrayBackedExtensionArray(ExtensionArray):
14+
"""
15+
ExtensionArray that is backed by a single NumPy ndarray.
16+
"""
17+
18+
_ndarray: np.ndarray
19+
20+
def _from_backing_data(self: _T, arr: np.ndarray) -> _T:
21+
"""
22+
Construct a new ExtensionArray `new_array` with `arr` as its _ndarray.
23+
24+
This should round-trip:
25+
self == self._from_backing_data(self._ndarray)
26+
"""
27+
raise AbstractMethodError(self)
28+
29+
# ------------------------------------------------------------------------
30+
31+
def take(
32+
self: _T,
33+
indices: Sequence[int],
34+
allow_fill: bool = False,
35+
fill_value: Any = None,
36+
) -> _T:
37+
if allow_fill:
38+
fill_value = self._validate_fill_value(fill_value)
39+
40+
new_data = take(
41+
self._ndarray, indices, allow_fill=allow_fill, fill_value=fill_value,
42+
)
43+
return self._from_backing_data(new_data)
44+
45+
def _validate_fill_value(self, fill_value):
46+
"""
47+
If a fill_value is passed to `take` convert it to a representation
48+
suitable for self._ndarray, raising ValueError if this is not possible.
49+
50+
Parameters
51+
----------
52+
fill_value : object
53+
54+
Returns
55+
-------
56+
fill_value : native representation
57+
58+
Raises
59+
------
60+
ValueError
61+
"""
62+
raise AbstractMethodError(self)

pandas/core/arrays/categorical.py

+20-12
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,10 @@
4949
from pandas.core import ops
5050
from pandas.core.accessor import PandasDelegate, delegate_names
5151
import pandas.core.algorithms as algorithms
52-
from pandas.core.algorithms import _get_data_algo, factorize, take, take_1d, unique1d
52+
from pandas.core.algorithms import _get_data_algo, factorize, take_1d, unique1d
5353
from pandas.core.array_algos.transforms import shift
54-
from pandas.core.arrays.base import ExtensionArray, _extension_array_shared_docs
54+
from pandas.core.arrays._mixins import _T, NDArrayBackedExtensionArray
55+
from pandas.core.arrays.base import _extension_array_shared_docs
5556
from pandas.core.base import NoNewAttributesMixin, PandasObject, _shared_docs
5657
import pandas.core.common as com
5758
from pandas.core.construction import array, extract_array, sanitize_array
@@ -199,7 +200,7 @@ def contains(cat, key, container):
199200
return any(loc_ in container for loc_ in loc)
200201

201202

202-
class Categorical(ExtensionArray, PandasObject):
203+
class Categorical(NDArrayBackedExtensionArray, PandasObject):
203204
"""
204205
Represent a categorical variable in classic R / S-plus fashion.
205206
@@ -1238,7 +1239,7 @@ def shift(self, periods, fill_value=None):
12381239

12391240
def _validate_fill_value(self, fill_value):
12401241
"""
1241-
Convert a user-facing fill_value to a representation to use with our
1242+
Convert a user-facing fill_value to a representation to use with our
12421243
underlying ndarray, raising ValueError if this is not possible.
12431244
12441245
Parameters
@@ -1768,7 +1769,7 @@ def fillna(self, value=None, method=None, limit=None):
17681769

17691770
return self._constructor(codes, dtype=self.dtype, fastpath=True)
17701771

1771-
def take(self, indexer, allow_fill: bool = False, fill_value=None):
1772+
def take(self: _T, indexer, allow_fill: bool = False, fill_value=None) -> _T:
17721773
"""
17731774
Take elements from the Categorical.
17741775
@@ -1837,16 +1838,23 @@ def take(self, indexer, allow_fill: bool = False, fill_value=None):
18371838
Categories (2, object): [a, b]
18381839
18391840
Specifying a fill value that's not in ``self.categories``
1840-
will raise a ``TypeError``.
1841+
will raise a ``ValueError``.
18411842
"""
1842-
indexer = np.asarray(indexer, dtype=np.intp)
1843+
return NDArrayBackedExtensionArray.take(
1844+
self, indexer, allow_fill=allow_fill, fill_value=fill_value
1845+
)
18431846

1844-
if allow_fill:
1845-
# convert user-provided `fill_value` to codes
1846-
fill_value = self._validate_fill_value(fill_value)
1847+
# ------------------------------------------------------------------
1848+
# NDArrayBackedExtensionArray compat
18471849

1848-
codes = take(self._codes, indexer, allow_fill=allow_fill, fill_value=fill_value)
1849-
return self._constructor(codes, dtype=self.dtype, fastpath=True)
1850+
@property
1851+
def _ndarray(self) -> np.ndarray:
1852+
return self._codes
1853+
1854+
def _from_backing_data(self, arr: np.ndarray) -> "Categorical":
1855+
return self._constructor(arr, dtype=self.dtype, fastpath=True)
1856+
1857+
# ------------------------------------------------------------------
18501858

18511859
def take_nd(self, indexer, allow_fill: bool = False, fill_value=None):
18521860
# GH#27745 deprecate alias that other EAs dont have

pandas/core/arrays/datetimelike.py

+19-12
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,9 @@
3939
from pandas.core.dtypes.missing import is_valid_nat_for_dtype, isna
4040

4141
from pandas.core import missing, nanops, ops
42-
from pandas.core.algorithms import checked_add_with_arr, take, unique1d, value_counts
42+
from pandas.core.algorithms import checked_add_with_arr, unique1d, value_counts
4343
from pandas.core.array_algos.transforms import shift
44+
from pandas.core.arrays._mixins import _T, NDArrayBackedExtensionArray
4445
from pandas.core.arrays.base import ExtensionArray, ExtensionOpsMixin
4546
import pandas.core.common as com
4647
from pandas.core.construction import array, extract_array
@@ -436,7 +437,9 @@ def _with_freq(self, freq):
436437
return self
437438

438439

439-
class DatetimeLikeArrayMixin(ExtensionOpsMixin, AttributesMixin, ExtensionArray):
440+
class DatetimeLikeArrayMixin(
441+
ExtensionOpsMixin, AttributesMixin, NDArrayBackedExtensionArray
442+
):
440443
"""
441444
Shared Base/Mixin class for DatetimeArray, TimedeltaArray, PeriodArray
442445
@@ -448,6 +451,20 @@ class DatetimeLikeArrayMixin(ExtensionOpsMixin, AttributesMixin, ExtensionArray)
448451
_generate_range
449452
"""
450453

454+
# ------------------------------------------------------------------
455+
# NDArrayBackedExtensionArray compat
456+
457+
@property
458+
def _ndarray(self) -> np.ndarray:
459+
# NB: A bunch of Interval tests fail if we use ._data
460+
return self.asi8
461+
462+
def _from_backing_data(self: _T, arr: np.ndarray) -> _T:
463+
# Note: we do not retain `freq`
464+
return type(self)(arr, dtype=self.dtype) # type: ignore
465+
466+
# ------------------------------------------------------------------
467+
451468
@property
452469
def ndim(self) -> int:
453470
return self._data.ndim
@@ -667,16 +684,6 @@ def unique(self):
667684
result = unique1d(self.asi8)
668685
return type(self)(result, dtype=self.dtype)
669686

670-
def take(self, indices, allow_fill=False, fill_value=None):
671-
if allow_fill:
672-
fill_value = self._validate_fill_value(fill_value)
673-
674-
new_values = take(
675-
self.asi8, indices, allow_fill=allow_fill, fill_value=fill_value
676-
)
677-
678-
return type(self)(new_values, dtype=self.dtype)
679-
680687
@classmethod
681688
def _concat_same_type(cls, to_concat, axis: int = 0):
682689

0 commit comments

Comments
 (0)