Skip to content

Commit 823adcf

Browse files
varadgunjal and TomAugspurger
authored and committed
Fixing shift() for ExtensionArray (#23947)
* Fixing shift() for ExtensionArray #23911
1 parent a102b0c commit 823adcf

File tree

4 files changed

+46
-7
lines changed

4 files changed

+46
-7
lines changed

doc/source/whatsnew/v0.24.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -990,6 +990,7 @@ update the ``ExtensionDtype._metadata`` tuple to match the signature of your
990990
- Added ``ExtensionDtype._is_numeric`` for controlling whether an extension dtype is considered numeric (:issue:`22290`).
991991
- The ``ExtensionArray`` constructor, ``_from_sequence`` now take the keyword arg ``copy=False`` (:issue:`21185`)
992992
- Bug in :meth:`Series.get` for ``Series`` using ``ExtensionArray`` and integer index (:issue:`21257`)
993+
- :meth:`pandas.api.extensions.ExtensionArray.shift` added as part of the basic ``ExtensionArray`` interface (:issue:`22387`).
993994
- :meth:`~Series.shift` now dispatches to :meth:`ExtensionArray.shift` (:issue:`22386`)
994995
- :meth:`Series.combine()` works correctly with :class:`~pandas.api.extensions.ExtensionArray` inside of :class:`Series` (:issue:`20825`)
995996
- :meth:`Series.combine()` with scalar argument now works for any function type (:issue:`21248`)

pandas/core/arrays/base.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,6 @@ def dropna(self):
442442
-------
443443
valid : ExtensionArray
444444
"""
445-
446445
return self[~self.isna()]
447446

448447
def shift(self, periods=1):
@@ -464,13 +463,25 @@ def shift(self, periods=1):
464463
Returns
465464
-------
466465
shifted : ExtensionArray
466+
467+
Notes
468+
-----
469+
If ``self`` is empty or ``periods`` is 0, a copy of ``self`` is
470+
returned.
471+
472+
If ``abs(periods) >= len(self)``, then an array of size
473+
len(self) is returned, with all values filled with
474+
``self.dtype.na_value``.
467475
"""
468476
# Note: this implementation assumes that `self.dtype.na_value` can be
469477
# stored in an instance of your ExtensionArray with `self.dtype`.
470-
if periods == 0:
478+
if not len(self) or periods == 0:
471479
return self.copy()
472-
empty = self._from_sequence([self.dtype.na_value] * abs(periods),
473-
dtype=self.dtype)
480+
481+
empty = self._from_sequence(
482+
[self.dtype.na_value] * min(abs(periods), len(self)),
483+
dtype=self.dtype
484+
)
474485
if periods > 0:
475486
a = empty
476487
b = self[:-periods]

pandas/core/arrays/sparse.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -889,7 +889,7 @@ def fillna(self, value=None, method=None, limit=None):
889889

890890
def shift(self, periods=1):
891891

892-
if periods == 0:
892+
if not len(self) or periods == 0:
893893
return self.copy()
894894

895895
subtype = np.result_type(np.nan, self.dtype.subtype)
@@ -900,8 +900,11 @@ def shift(self, periods=1):
900900
else:
901901
arr = self
902902

903-
empty = self._from_sequence([self.dtype.na_value] * abs(periods),
904-
dtype=arr.dtype)
903+
empty = self._from_sequence(
904+
[self.dtype.na_value] * min(abs(periods), len(self)),
905+
dtype=arr.dtype
906+
)
907+
905908
if periods > 0:
906909
a = empty
907910
b = arr[:-periods]

pandas/tests/extension/base/methods.py

+24
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,30 @@ def test_container_shift(self, data, frame, periods, indices):
198198

199199
compare(result, expected)
200200

201+
@pytest.mark.parametrize('periods, indices', [
202+
[-4, [-1, -1]],
203+
[-1, [1, -1]],
204+
[0, [0, 1]],
205+
[1, [-1, 0]],
206+
[4, [-1, -1]]
207+
])
208+
def test_shift_non_empty_array(self, data, periods, indices):
209+
# https://github.com/pandas-dev/pandas/issues/23911
210+
subset = data[:2]
211+
result = subset.shift(periods)
212+
expected = subset.take(indices, allow_fill=True)
213+
self.assert_extension_array_equal(result, expected)
214+
215+
@pytest.mark.parametrize('periods', [
216+
-4, -1, 0, 1, 4
217+
])
218+
def test_shift_empty_array(self, data, periods):
219+
# https://github.com/pandas-dev/pandas/issues/23911
220+
empty = data[:0]
221+
result = empty.shift(periods)
222+
expected = empty
223+
self.assert_extension_array_equal(result, expected)
224+
201225
@pytest.mark.parametrize("as_frame", [True, False])
202226
def test_hash_pandas_object_works(self, data, as_frame):
203227
# https://github.com/pandas-dev/pandas/issues/23066

0 commit comments

Comments
 (0)