diff --git a/doc/source/whatsnew/v0.24.0.rst b/doc/source/whatsnew/v0.24.0.rst index 0a066399e27ca..74940e8985ef3 100644 --- a/doc/source/whatsnew/v0.24.0.rst +++ b/doc/source/whatsnew/v0.24.0.rst @@ -914,6 +914,7 @@ update the ``ExtensionDtype._metadata`` tuple to match the signature of your - Added ``ExtensionDtype._is_numeric`` for controlling whether an extension dtype is considered numeric (:issue:`22290`). - The ``ExtensionArray`` constructor, ``_from_sequence`` now take the keyword arg ``copy=False`` (:issue:`21185`) - Bug in :meth:`Series.get` for ``Series`` using ``ExtensionArray`` and integer index (:issue:`21257`) +- :meth:`pandas.api.extensions.ExtensionArray.shift` added as part of the basic ``ExtensionArray`` interface (:issue:`22387`). - :meth:`~Series.shift` now dispatches to :meth:`ExtensionArray.shift` (:issue:`22386`) - :meth:`Series.combine()` works correctly with :class:`~pandas.api.extensions.ExtensionArray` inside of :class:`Series` (:issue:`20825`) - :meth:`Series.combine()` with scalar argument now works for any function type (:issue:`21248`) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index eb2fef482ff17..b22bdf3a3f19b 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -418,13 +418,13 @@ def fillna(self, value=None, method=None, limit=None): return new_values def dropna(self): - """ Return ExtensionArray without NA values + """ + Return ExtensionArray without NA values Returns ------- valid : ExtensionArray """ - return self[~self.isna()] def shift(self, periods=1): @@ -446,13 +446,25 @@ def shift(self, periods=1): Returns ------- shifted : ExtensionArray + + Notes + ----- + If ``self`` is empty or ``periods`` is 0, a copy of ``self`` is + returned. + + If ``abs(periods) >= len(self)``, then an array of size + len(self) is returned, with all values filled with + ``self.dtype.na_value``.
""" # Note: this implementation assumes that `self.dtype.na_value` can be # stored in an instance of your ExtensionArray with `self.dtype`. - if periods == 0: + if not len(self) or periods == 0: return self.copy() - empty = self._from_sequence([self.dtype.na_value] * abs(periods), - dtype=self.dtype) + + empty = self._from_sequence( + [self.dtype.na_value] * min(abs(periods), len(self)), + dtype=self.dtype + ) if periods > 0: a = empty b = self[:-periods] @@ -462,7 +474,8 @@ def shift(self, periods=1): return self._concat_same_type([a, b]) def unique(self): - """Compute the ExtensionArray of unique values. + """ + Compute the ExtensionArray of unique values. Returns ------- diff --git a/pandas/core/arrays/sparse.py b/pandas/core/arrays/sparse.py index 9a5ef3b3a7dd0..b4b511cec3601 100644 --- a/pandas/core/arrays/sparse.py +++ b/pandas/core/arrays/sparse.py @@ -882,7 +882,7 @@ def fillna(self, value=None, method=None, limit=None): def shift(self, periods=1): - if periods == 0: + if not len(self) or periods == 0: return self.copy() subtype = np.result_type(np.nan, self.dtype.subtype) @@ -893,8 +893,11 @@ def shift(self, periods=1): else: arr = self - empty = self._from_sequence([self.dtype.na_value] * abs(periods), - dtype=arr.dtype) + empty = self._from_sequence( + [self.dtype.na_value] * min(abs(periods), len(self)), + dtype=arr.dtype + ) + if periods > 0: a = empty b = arr[:-periods] diff --git a/pandas/tests/extension/base/methods.py b/pandas/tests/extension/base/methods.py index e9a89c1af2f22..ace5c5346bf5a 100644 --- a/pandas/tests/extension/base/methods.py +++ b/pandas/tests/extension/base/methods.py @@ -189,6 +189,30 @@ def test_container_shift(self, data, frame, periods, indices): compare(result, expected) + @pytest.mark.parametrize('periods, indices', [ + [-4, [-1, -1]], + [-1, [1, -1]], + [0, [0, 1]], + [1, [-1, 0]], + [4, [-1, -1]] + ]) + def test_shift_non_empty_array(self, data, periods, indices): + # https://github.com/pandas-dev/pandas/issues/23911 + 
subset = data[:2] + result = subset.shift(periods) + expected = subset.take(indices, allow_fill=True) + self.assert_extension_array_equal(result, expected) + + @pytest.mark.parametrize('periods', [ + -4, -1, 0, 1, 4 + ]) + def test_shift_empty_array(self, data, periods): + # https://github.com/pandas-dev/pandas/issues/23911 + empty = data[:0] + result = empty.shift(periods) + expected = empty + self.assert_extension_array_equal(result, expected) + @pytest.mark.parametrize("as_frame", [True, False]) def test_hash_pandas_object_works(self, data, as_frame): # https://github.com/pandas-dev/pandas/issues/23066