Skip to content

Commit e7fca91

Browse files
TomAugspurger authored and jreback committed
Support NDFrame.shift with EAs (#22387)
1 parent dd24e76 commit e7fca91

File tree

4 files changed

+76
-4
lines changed

4 files changed

+76
-4
lines changed

doc/source/whatsnew/v0.24.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,7 @@ ExtensionType Changes
447447
- Added ``ExtensionDtype._is_numeric`` for controlling whether an extension dtype is considered numeric (:issue:`22290`).
448448
- The ``ExtensionArray`` constructor, ``_from_sequence``, now takes the keyword arg ``copy=False`` (:issue:`21185`)
449449
- Bug in :meth:`Series.get` for ``Series`` using ``ExtensionArray`` and integer index (:issue:`21257`)
450+
- :meth:`~Series.shift` now dispatches to :meth:`ExtensionArray.shift` (:issue:`22386`)
450451
- :meth:`Series.combine()` works correctly with :class:`~pandas.api.extensions.ExtensionArray` inside of :class:`Series` (:issue:`20825`)
451452
- :meth:`Series.combine()` with scalar argument now works for any function type (:issue:`21248`)
452453
- :meth:`Series.astype` and :meth:`DataFrame.astype` now dispatch to :meth:`ExtensionArray.astype` (:issue:`21185`).

pandas/core/arrays/base.py

+38
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ class ExtensionArray(object):
5959
* factorize / _values_for_factorize
6060
* argsort / _values_for_argsort
6161
62+
The remaining methods implemented on this class should be performant,
63+
as they only compose abstract methods. Still, a more efficient
64+
implementation may be available, and these methods can be overridden.
65+
6266
This class does not inherit from 'abc.ABCMeta' for performance reasons.
6367
Methods and properties required by the interface raise
6468
``pandas.errors.AbstractMethodError`` and no ``register`` method is
@@ -400,6 +404,40 @@ def dropna(self):
400404

401405
return self[~self.isna()]
402406

407+
def shift(self, periods=1):
    # type: (int) -> ExtensionArray
    """
    Shift values by desired number.

    Newly introduced missing values are filled with
    ``self.dtype.na_value``. The result always has the same length
    as ``self``, even when ``abs(periods)`` exceeds ``len(self)``
    (in which case every element is ``self.dtype.na_value``).

    .. versionadded:: 0.24.0

    Parameters
    ----------
    periods : int, default 1
        The number of periods to shift. Negative values are allowed
        for shifting backwards.

    Returns
    -------
    shifted : ExtensionArray
        A new array; ``self`` is not modified.
    """
    # Note: this implementation assumes that `self.dtype.na_value` can be
    # stored in an instance of your ExtensionArray with `self.dtype`.
    if periods == 0 or not len(self):
        # Nothing to move; still hand back a new object.
        return self.copy()

    # Never create more fill values than there are positions to fill;
    # otherwise shifting by more than len(self) would grow the array.
    n_fill = min(abs(periods), len(self))
    empty = self._from_sequence([self.dtype.na_value] * n_fill,
                                dtype=self.dtype)
    if periods > 0:
        # Shift forwards: missing values lead, the tail of self is dropped.
        a = empty
        b = self[:-periods]
    else:
        # Shift backwards: the head of self is dropped, missing values trail.
        a = self[abs(periods):]
        b = empty
    return self._concat_same_type([a, b])
440+
403441
def unique(self):
404442
"""Compute the ExtensionArray of unique values.
405443

pandas/core/internals/blocks.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -2068,6 +2068,18 @@ def interpolate(self, method='pad', axis=0, inplace=False, limit=None,
20682068
limit=limit),
20692069
placement=self.mgr_locs)
20702070

2071+
def shift(self, periods, axis=0, mgr=None):
    # type: (int, int, Optional[Any]) -> List[ExtensionBlock]
    """
    Shift the block's values by `periods` positions.

    The work is delegated to the ``shift`` method of the underlying
    values, and the shifted values are wrapped in a new block of the
    same class, keeping this block's placement and ``ndim``.

    Parameters
    ----------
    periods : int
        Number of positions to shift; forwarded as ``periods=``.
    axis : int, default 0
        Accepted for signature compatibility; not used here.
    mgr : optional
        Accepted for signature compatibility; not used here.

    Returns
    -------
    list
        A single-element list holding the new block.
    """
    shifted_values = self.values.shift(periods=periods)
    new_block = self.make_block_same_class(shifted_values,
                                           placement=self.mgr_locs,
                                           ndim=self.ndim)
    return [new_block]
2082+
20712083

20722084
class NumericBlock(Block):
20732085
__slots__ = ()
@@ -2691,10 +2703,6 @@ def _try_coerce_result(self, result):
26912703

26922704
return result
26932705

2694-
def shift(self, periods, axis=0, mgr=None):
2695-
return self.make_block_same_class(values=self.values.shift(periods),
2696-
placement=self.mgr_locs)
2697-
26982706
def to_dense(self):
26992707
# Categorical.get_values returns a DatetimeIndex for datetime
27002708
# categories, so we can't simply use `np.asarray(self.values)` like

pandas/tests/extension/base/methods.py

+25
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,28 @@ def test_combine_add(self, data_repeated):
138138
expected = pd.Series(
139139
orig_data1._from_sequence([a + val for a in list(orig_data1)]))
140140
self.assert_series_equal(result, expected)
141+
142+
@pytest.mark.parametrize('frame', [True, False])
@pytest.mark.parametrize('periods, indices', [
    (-2, [2, 3, 4, -1, -1]),
    (0, [0, 1, 2, 3, 4]),
    (2, [-1, -1, 0, 1, 2]),
])
def test_container_shift(self, data, frame, periods, indices):
    """
    Check that shifting a Series / DataFrame holding extension data
    matches a ``take`` with ``-1`` marking the filled positions.
    """
    # https://github.com/pandas-dev/pandas/issues/22386
    head = data[:5]
    ser = pd.Series(head, name='A')
    expected = pd.Series(head.take(indices, allow_fill=True), name='A')

    if not frame:
        # Series case: compare the shifted Series directly.
        self.assert_series_equal(ser.shift(periods), expected)
        return

    # Frame case: add a plain integer column 'B' next to the extension
    # column and shift the whole frame.
    result = ser.to_frame(name='A').assign(B=1).shift(periods)
    shifted_b = pd.Series([1] * 5, name='B').shift(periods)
    expected = pd.concat([expected, shifted_b], axis=1)
    self.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)