diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py
index c281bd80cb274..d49a0d799526a 100644
--- a/pandas/core/arrays/base.py
+++ b/pandas/core/arrays/base.py
@@ -458,11 +458,23 @@ def take(self, indexer, allow_fill=True, fill_value=None):
             Fill value to replace -1 values with. If applicable, this should
             use the sentinel missing value for this type.
 
+        Returns
+        -------
+        ExtensionArray
+
+        Raises
+        ------
+        IndexError
+            When the indexer is out of bounds for the array.
+
         Notes
         -----
         This should follow pandas' semantics where -1 indicates missing values.
         Positions where indexer is ``-1`` should be filled with the missing
         value for this type.
+        This gives rise to the special case of a take on an empty
+        ExtensionArray that does not raise an IndexError straight away
+        when the `indexer` is all ``-1``.
 
         This is called by ``Series.__getitem__``, ``.loc``, ``iloc``, when the
         indexer is a sequence of values.
@@ -477,6 +489,12 @@ def take(self, indexer, allow_fill=True, fill_value=None):
            def take(self, indexer, allow_fill=True, fill_value=None):
                indexer = np.asarray(indexer)
                mask = indexer == -1
+
+               # take on empty array not handled as desired by numpy
+               # in case of -1 (all missing take)
+               if not len(self) and mask.all():
+                   return type(self)([np.nan] * len(indexer))
+
                result = self.data.take(indexer)
                result[mask] = np.nan  # NA for this type
                return type(self)(result)
diff --git a/pandas/tests/extension/base/getitem.py b/pandas/tests/extension/base/getitem.py
index 566ba1721d13c..4e2a65eba06dc 100644
--- a/pandas/tests/extension/base/getitem.py
+++ b/pandas/tests/extension/base/getitem.py
@@ -1,6 +1,8 @@
+import pytest
 import numpy as np
 
 import pandas as pd
+import pandas.util.testing as tm
 
 from .base import BaseExtensionTests
 
@@ -120,3 +122,48 @@ def test_take_sequence(self, data):
         assert result.iloc[0] == data[0]
         assert result.iloc[1] == data[1]
         assert result.iloc[2] == data[3]
+
+    def test_take(self, data, na_value, na_cmp):
+        result = data.take([0, -1])
+        assert result.dtype == data.dtype
+        assert result[0] == data[0]
+        assert na_cmp(result[1], na_value)
+
+        with tm.assert_raises_regex(IndexError, "out of bounds"):
+            data.take([len(data) + 1])
+
+    def test_take_empty(self, data, na_value, na_cmp):
+        empty = data[:0]
+        result = empty.take([-1])
+        assert na_cmp(result[0], na_value)
+
+        with tm.assert_raises_regex(IndexError, "cannot do a non-empty take"):
+            empty.take([0, 1])
+
+    @pytest.mark.xfail(reason="Series.take with extension array buggy for -1")
+    def test_take_series(self, data):
+        s = pd.Series(data)
+        result = s.take([0, -1])
+        expected = pd.Series(
+            data._constructor_from_sequence([data[0], data[len(data) - 1]]),
+            index=[0, len(data) - 1])
+        self.assert_series_equal(result, expected)
+
+    def test_reindex(self, data, na_value):
+        s = pd.Series(data)
+        result = s.reindex([0, 1, 3])
+        expected = pd.Series(data.take([0, 1, 3]), index=[0, 1, 3])
+        self.assert_series_equal(result, expected)
+
+        n = len(data)
+        result = s.reindex([-1, 0, n])
+        expected = pd.Series(
+            data._constructor_from_sequence([na_value, data[0], na_value]),
+            index=[-1, 0, n])
+        self.assert_series_equal(result, expected)
+
+        result = s.reindex([n, n + 1])
+        expected = pd.Series(
+            data._constructor_from_sequence([na_value, na_value]),
+            index=[n, n + 1])
+        self.assert_series_equal(result, expected)
diff --git a/pandas/tests/extension/category/test_categorical.py b/pandas/tests/extension/category/test_categorical.py
index 6abf1f7f9a65a..27c156c15203f 100644
--- a/pandas/tests/extension/category/test_categorical.py
+++ b/pandas/tests/extension/category/test_categorical.py
@@ -84,6 +84,19 @@ def test_getitem_scalar(self):
         # to break things by changing.
         pass
 
+    @pytest.mark.skip(reason="Categorical.take buggy")
+    def test_take(self):
+        # TODO remove this once Categorical.take is fixed
+        pass
+
+    @pytest.mark.skip(reason="Categorical.take buggy")
+    def test_take_empty(self):
+        pass
+
+    @pytest.mark.skip(reason="test not written correctly for categorical")
+    def test_reindex(self):
+        pass
+
 
 class TestSetitem(base.BaseSetitemTests):
     pass
diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py
index f93d11f579f11..a8e88365b5648 100644
--- a/pandas/tests/extension/decimal/array.py
+++ b/pandas/tests/extension/decimal/array.py
@@ -81,6 +81,10 @@ def take(self, indexer, allow_fill=True, fill_value=None):
         indexer = np.asarray(indexer)
         mask = indexer == -1
 
+        # take on empty array not handled as desired by numpy in case of -1
+        if not len(self) and mask.all():
+            return type(self)([self._na_value] * len(indexer))
+
         indexer = _ensure_platform_int(indexer)
         out = self.values.take(indexer)
         out[mask] = self._na_value
diff --git a/pandas/tests/extension/json/array.py b/pandas/tests/extension/json/array.py
index d9ae49d87804a..33843492cb706 100644
--- a/pandas/tests/extension/json/array.py
+++ b/pandas/tests/extension/json/array.py
@@ -89,8 +89,12 @@ def isna(self):
         return np.array([x == self._na_value for x in self.data])
 
     def take(self, indexer, allow_fill=True, fill_value=None):
-        output = [self.data[loc] if loc != -1 else self._na_value
-                  for loc in indexer]
+        try:
+            output = [self.data[loc] if loc != -1 else self._na_value
+                      for loc in indexer]
+        except IndexError:
+            raise IndexError("Index is out of bounds or cannot do a "
+                             "non-empty take from an empty array.")
         return self._constructor_from_sequence(output)
 
     def copy(self, deep=False):