Skip to content

Commit fb3c234

Browse files
committed
[WIP]: ExtensionArray.take default implementation
Implements a take interface that's compatible with NumPy and optionally pandas' NA semantics. ```python In [1]: import pandas as pd In [2]: from pandas.tests.extension.decimal.array import * In [3]: arr = DecimalArray(['1.1', '1.2', '1.3']) In [4]: arr.take([0, 1, -1]) Out[4]: DecimalArray(array(['1.1', '1.2', '1.3'], dtype=object)) In [5]: arr.take([0, 1, -1], fill_value=float('nan')) Out[5]: DecimalArray(array(['1.1', '1.2', Decimal('NaN')], dtype=object)) ``` Closes pandas-dev#20640
1 parent 8bee97a commit fb3c234

File tree

8 files changed

+174
-57
lines changed

8 files changed

+174
-57
lines changed

pandas/compat/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -451,3 +451,5 @@ def is_platform_mac():
451451

452452
def is_platform_32bit():
453453
return struct.calcsize("P") * 8 < 64
454+
455+
_default_fill_value = object()

pandas/core/algorithms.py

+77-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
_ensure_platform_int, _ensure_object,
3131
_ensure_float64, _ensure_uint64,
3232
_ensure_int64)
33+
from pandas.compat import _default_fill_value
3334
from pandas.compat.numpy import _np_version_under1p10
3435
from pandas.core.dtypes.missing import isna, na_value_for_dtype
3536

@@ -1482,7 +1483,7 @@ def take_nd(arr, indexer, axis=0, out=None, fill_value=np.nan, mask_info=None,
14821483
# TODO(EA): Remove these if / elifs as datetimeTZ, interval, become EAs
14831484
# dispatch to internal type takes
14841485
if is_extension_array_dtype(arr):
1485-
return arr.take(indexer, fill_value=fill_value, allow_fill=allow_fill)
1486+
return arr.take(indexer, fill_value=fill_value)
14861487
elif is_datetimetz(arr):
14871488
return arr.take(indexer, fill_value=fill_value, allow_fill=allow_fill)
14881489
elif is_interval_dtype(arr):
@@ -1558,6 +1559,81 @@ def take_nd(arr, indexer, axis=0, out=None, fill_value=np.nan, mask_info=None,
15581559
take_1d = take_nd
15591560

15601561

1562+
def take_ea(arr, indexer, fill_value=_default_fill_value):
1563+
"""Extension-array compatible take.
1564+
1565+
Parameters
1566+
----------
1567+
arr : array-like
1568+
Must satisify NumPy's take semantics.
1569+
indexer : sequence of integers
1570+
Indices to be taken. See Notes for how negative indicies
1571+
are handled.
1572+
fill_value : any, optional
1573+
Fill value to use for NA-indicies. This has a few behaviors.
1574+
1575+
* fill_value is not specified : triggers NumPy's semantics
1576+
where negative values in `indexer` mean slices from the end.
1577+
* fill_value is NA : Fill positions where `indexer` is ``-1``
1578+
with ``self.dtype.na_value``. Anything considered NA by
1579+
:func:`pandas.isna` will result in ``self.dtype.na_value``
1580+
being used to fill.
1581+
* fill_value is not NA : Fill positions where `indexer` is ``-1``
1582+
with `fill_value`.
1583+
1584+
Returns
1585+
-------
1586+
ExtensionArray
1587+
1588+
Raises
1589+
------
1590+
IndexError
1591+
When the indexer is out of bounds for the array.
1592+
ValueError
1593+
When the indexer contains negative values other than ``-1``
1594+
and `fill_value` is specified.
1595+
1596+
Notes
1597+
-----
1598+
The meaning of negative values in `indexer` depends on the
1599+
`fill_value` argument. By default, we follow the behavior
1600+
:meth:`numpy.take` of where negative indices indicate slices
1601+
from the end.
1602+
1603+
When `fill_value` is specified, we follow pandas semantics of ``-1``
1604+
indicating a missing value. In this case, positions where `indexer`
1605+
is ``-1`` will be filled with `fill_value` or the default NA value
1606+
for this type.
1607+
1608+
ExtensionArray.take is called by ``Series.__getitem__``, ``.loc``,
1609+
``iloc``, when the indexer is a sequence of values. Additionally,
1610+
it's called by :meth:`Series.reindex` with a `fill_value`.
1611+
1612+
See Also
1613+
--------
1614+
numpy.take
1615+
"""
1616+
indexer = np.asarray(indexer)
1617+
if fill_value is _default_fill_value:
1618+
# NumPy style
1619+
result = arr.take(indexer)
1620+
else:
1621+
mask = indexer == -1
1622+
if (indexer < -1).any():
1623+
raise ValueError("Invalid value in 'indexer'. All values "
1624+
"must be non-negative or -1. When "
1625+
"'fill_value' is specified.")
1626+
1627+
# take on empty array not handled as desired by numpy
1628+
# in case of -1 (all missing take)
1629+
if not len(arr) and mask.all():
1630+
return arr._from_sequence([fill_value] * len(indexer))
1631+
1632+
result = arr.take(indexer)
1633+
result[mask] = fill_value
1634+
return result
1635+
1636+
15611637
def take_2d_multi(arr, indexer, out=None, fill_value=np.nan, mask_info=None,
15621638
allow_fill=True):
15631639
"""

pandas/core/arrays/base.py

+52-39
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy as np
99

1010
from pandas.errors import AbstractMethodError
11+
from pandas.compat import _default_fill_value
1112
from pandas.compat.numpy import function as nv
1213

1314
_not_implemented_message = "{} does not implement {}."
@@ -53,6 +54,7 @@ class ExtensionArray(object):
5354
* unique
5455
* factorize / _values_for_factorize
5556
* argsort / _values_for_argsort
57+
* take / _values_for_take
5658
5759
This class does not inherit from 'abc.ABCMeta' for performance reasons.
5860
Methods and properties required by the interface raise
@@ -462,22 +464,38 @@ def factorize(self, na_sentinel=-1):
462464
# ------------------------------------------------------------------------
463465
# Indexing methods
464466
# ------------------------------------------------------------------------
465-
def take(self, indexer, allow_fill=True, fill_value=None):
466-
# type: (Sequence[int], bool, Optional[Any]) -> ExtensionArray
467+
def _values_for_take(self):
468+
"""Values to use for `take`.
469+
470+
Coerces to object dtype by default.
471+
472+
Returns
473+
-------
474+
array-like
475+
Must satisify NumPy's `take` semantics.
476+
"""
477+
return self.astype(object)
478+
479+
def take(self, indexer, fill_value=_default_fill_value):
480+
# type: (Sequence[int], Optional[Any]) -> ExtensionArray
467481
"""Take elements from an array.
468482
469483
Parameters
470484
----------
471485
indexer : sequence of integers
472-
indices to be taken. -1 is used to indicate values
473-
that are missing.
474-
allow_fill : bool, default True
475-
If False, indexer is assumed to contain no -1 values so no filling
476-
will be done. This short-circuits computation of a mask. Result is
477-
undefined if allow_fill == False and -1 is present in indexer.
478-
fill_value : any, default None
479-
Fill value to replace -1 values with. If applicable, this should
480-
use the sentinel missing value for this type.
486+
Indices to be taken. See Notes for how negative indicies
487+
are handled.
488+
fill_value : any, optional
489+
Fill value to use for NA-indicies. This has a few behaviors.
490+
491+
* fill_value is not specified : triggers NumPy's semantics
492+
where negative values in `indexer` mean slices from the end.
493+
* fill_value is NA : Fill positions where `indexer` is ``-1``
494+
with ``self.dtype.na_value``. Anything considered NA by
495+
:func:`pandas.isna` will result in ``self.dtype.na_value``
496+
being used to fill.
497+
* fill_value is not NA : Fill positions where `indexer` is ``-1``
498+
with `fill_value`.
481499
482500
Returns
483501
-------
@@ -487,44 +505,39 @@ def take(self, indexer, allow_fill=True, fill_value=None):
487505
------
488506
IndexError
489507
When the indexer is out of bounds for the array.
508+
ValueError
509+
When the indexer contains negative values other than ``-1``
510+
and `fill_value` is specified.
490511
491512
Notes
492513
-----
493-
This should follow pandas' semantics where -1 indicates missing values.
494-
Positions where indexer is ``-1`` should be filled with the missing
495-
value for this type.
496-
This gives rise to the special case of a take on an empty
497-
ExtensionArray that does not raises an IndexError straight away
498-
when the `indexer` is all ``-1``.
499-
500-
This is called by ``Series.__getitem__``, ``.loc``, ``iloc``, when the
501-
indexer is a sequence of values.
514+
The meaning of negative values in `indexer` depends on the
515+
`fill_value` argument. By default, we follow the behavior
516+
:meth:`numpy.take` of where negative indices indicate slices
517+
from the end.
502518
503-
Examples
504-
--------
505-
Suppose the extension array is backed by a NumPy array stored as
506-
``self.data``. Then ``take`` may be written as
507-
508-
.. code-block:: python
509-
510-
def take(self, indexer, allow_fill=True, fill_value=None):
511-
indexer = np.asarray(indexer)
512-
mask = indexer == -1
513-
514-
# take on empty array not handled as desired by numpy
515-
# in case of -1 (all missing take)
516-
if not len(self) and mask.all():
517-
return type(self)([np.nan] * len(indexer))
519+
When `fill_value` is specified, we follow pandas semantics of ``-1``
520+
indicating a missing value. In this case, positions where `indexer`
521+
is ``-1`` will be filled with `fill_value` or the default NA value
522+
for this type.
518523
519-
result = self.data.take(indexer)
520-
result[mask] = np.nan # NA for this type
521-
return type(self)(result)
524+
ExtensionArray.take is called by ``Series.__getitem__``, ``.loc``,
525+
``iloc``, when the indexer is a sequence of values. Additionally,
526+
it's called by :meth:`Series.reindex` with a `fill_value`.
522527
523528
See Also
524529
--------
525530
numpy.take
526531
"""
527-
raise AbstractMethodError(self)
532+
from pandas.core.algorithms import take_ea
533+
from pandas.core.missing import isna
534+
535+
if isna(fill_value):
536+
fill_value = self.dtype.na_value
537+
538+
data = self._values_for_take()
539+
result = take_ea(data, indexer, fill_value=fill_value)
540+
return self._from_sequence(result)
528541

529542
def copy(self, deep=False):
530543
# type: (bool) -> ExtensionArray

pandas/core/dtypes/base.py

+4
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ class _DtypeOpsMixin(object):
1616
# classes will inherit from this Mixin. Once everything is compatible, this
1717
# class's methods can be moved to ExtensionDtype and removed.
1818

19+
# na_value is the default NA value to use for this type. This is used in
20+
# e.g. ExtensionArray.take.
21+
na_value = np.nan
22+
1923
def __eq__(self, other):
2024
"""Check whether 'other' is equal to self.
2125

pandas/core/dtypes/missing.py

+2
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,8 @@ def na_value_for_dtype(dtype, compat=True):
502502
"""
503503
dtype = pandas_dtype(dtype)
504504

505+
if is_extension_array_dtype(dtype):
506+
return dtype.na_value
505507
if (is_datetime64_dtype(dtype) or is_datetime64tz_dtype(dtype) or
506508
is_timedelta64_dtype(dtype) or is_period_dtype(dtype)):
507509
return NaT

pandas/tests/extension/base/getitem.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,29 @@ def test_take(self, data, na_value, na_cmp):
134134

135135
def test_take_empty(self, data, na_value, na_cmp):
136136
empty = data[:0]
137-
result = empty.take([-1])
138-
na_cmp(result[0], na_value)
137+
# result = empty.take([-1])
138+
# na_cmp(result[0], na_value)
139139

140140
with tm.assert_raises_regex(IndexError, "cannot do a non-empty take"):
141141
empty.take([0, 1])
142142

143+
def test_take_negative(self, data):
144+
# https://github.com/pandas-dev/pandas/issues/20640
145+
n = len(data)
146+
result = data.take([0, -n, n - 1, -1])
147+
expected = data.take([0, 0, n - 1, n - 1])
148+
self.assert_extension_array_equal(result, expected)
149+
150+
def test_take_non_na_fill_value(self, data_missing):
151+
fill_value = data_missing[1] # valid
152+
result = data_missing.take([-1, 1], fill_value=fill_value)
153+
expected = data_missing.take([1, 1])
154+
self.assert_extension_array_equal(result, expected)
155+
156+
def test_take_pandas_style_negative_raises(self, data, na_value):
157+
with pytest.raises(ValueError):
158+
data.take([0, -2], fill_value=na_value)
159+
143160
@pytest.mark.xfail(reason="Series.take with extension array buggy for -1")
144161
def test_take_series(self, data):
145162
s = pd.Series(data)

pandas/tests/extension/category/test_categorical.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -81,18 +81,32 @@ def test_merge(self, data, na_value):
8181

8282

8383
class TestGetitem(base.BaseGetitemTests):
84+
skip_take = pytest.mark.skip(reason="GH-20664.")
85+
8486
@pytest.mark.skip(reason="Backwards compatibility")
8587
def test_getitem_scalar(self):
8688
# CategoricalDtype.type isn't "correct" since it should
8789
# be a parent of the elements (object). But don't want
8890
# to break things by changing.
8991
pass
9092

91-
@pytest.mark.xfail(reason="Categorical.take buggy")
93+
@skip_take
9294
def test_take(self):
9395
# TODO remove this once Categorical.take is fixed
9496
pass
9597

98+
@skip_take
99+
def test_take_negative(self):
100+
pass
101+
102+
@skip_take
103+
def test_take_pandas_style_negative_raises(self):
104+
pass
105+
106+
@skip_take
107+
def test_take_non_na_fill_value(self):
108+
pass
109+
96110
@pytest.mark.xfail(reason="Categorical.take buggy")
97111
def test_take_empty(self):
98112
pass

pandas/tests/extension/decimal/array.py

+3-14
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
import pandas as pd
99
from pandas.core.arrays import ExtensionArray
1010
from pandas.core.dtypes.base import ExtensionDtype
11-
from pandas.core.dtypes.common import _ensure_platform_int
1211

1312

1413
class DecimalDtype(ExtensionDtype):
1514
type = decimal.Decimal
1615
name = 'decimal'
16+
na_value = decimal.Decimal('NaN')
1717

1818
@classmethod
1919
def construct_from_string(cls, string):
@@ -80,19 +80,8 @@ def nbytes(self):
8080
def isna(self):
8181
return np.array([x.is_nan() for x in self._data])
8282

83-
def take(self, indexer, allow_fill=True, fill_value=None):
84-
indexer = np.asarray(indexer)
85-
mask = indexer == -1
86-
87-
# take on empty array not handled as desired by numpy in case of -1
88-
if not len(self) and mask.all():
89-
return type(self)([self._na_value] * len(indexer))
90-
91-
indexer = _ensure_platform_int(indexer)
92-
out = self._data.take(indexer)
93-
out[mask] = self._na_value
94-
95-
return type(self)(out)
83+
def _values_for_take(self):
84+
return self.data
9685

9786
@property
9887
def _na_value(self):

0 commit comments

Comments
 (0)