Skip to content

Commit d28db65

Browse files
BUG: Fixed IntervalArray[int].shift (#31502)
1 parent 01582c4 commit d28db65

File tree

4 files changed

+56
-0
lines changed

4 files changed

+56
-0
lines changed

doc/source/whatsnew/v1.0.1.rst

+3
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ Bug fixes
5757

5858
- Plotting tz-aware timeseries no longer gives UserWarning (:issue:`31205`)
5959

60+
**Interval**
61+
62+
- Bug in :meth:`Series.shift` with ``interval`` dtype raising a ``TypeError`` when shifting an interval array of integers or datetimes (:issue:`34195`)
6063

6164
.. ---------------------------------------------------------------------------
6265

pandas/core/arrays/interval.py

+28
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from pandas.core.dtypes.dtypes import IntervalDtype
2828
from pandas.core.dtypes.generic import (
2929
ABCDatetimeIndex,
30+
ABCExtensionArray,
3031
ABCIndexClass,
3132
ABCInterval,
3233
ABCIntervalIndex,
@@ -789,6 +790,33 @@ def size(self) -> int:
789790
# Avoid materializing self.values
790791
return self.left.size
791792

793+
def shift(self, periods: int = 1, fill_value: object = None) -> ABCExtensionArray:
794+
if not len(self) or periods == 0:
795+
return self.copy()
796+
797+
if isna(fill_value):
798+
fill_value = self.dtype.na_value
799+
800+
# ExtensionArray.shift doesn't work for two reasons
801+
# 1. IntervalArray.dtype.na_value may not be correct for the dtype.
802+
# 2. IntervalArray._from_sequence only accepts NaN for missing values,
803+
# not other values like NaT
804+
805+
empty_len = min(abs(periods), len(self))
806+
if isna(fill_value):
807+
fill_value = self.left._na_value
808+
empty = IntervalArray.from_breaks([fill_value] * (empty_len + 1))
809+
else:
810+
empty = self._from_sequence([fill_value] * empty_len)
811+
812+
if periods > 0:
813+
a = empty
814+
b = self[:-periods]
815+
else:
816+
a = self[abs(periods) :]
817+
b = empty
818+
return self._concat_same_type([a, b])
819+
792820
def take(self, indices, allow_fill=False, fill_value=None, axis=None, **kwargs):
793821
"""
794822
Take elements from the IntervalArray.

pandas/tests/arrays/interval/test_interval.py

+18
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,24 @@ def test_where_raises(self, other):
8181
with pytest.raises(ValueError, match=match):
8282
ser.where([True, False, True], other=other)
8383

84+
def test_shift(self):
85+
# https://github.com/pandas-dev/pandas/issues/31495
86+
a = IntervalArray.from_breaks([1, 2, 3])
87+
result = a.shift()
88+
# int -> float
89+
expected = IntervalArray.from_tuples([(np.nan, np.nan), (1.0, 2.0)])
90+
tm.assert_interval_array_equal(result, expected)
91+
92+
def test_shift_datetime(self):
93+
a = IntervalArray.from_breaks(pd.date_range("2000", periods=4))
94+
result = a.shift(2)
95+
expected = a.take([-1, -1, 0], allow_fill=True)
96+
tm.assert_interval_array_equal(result, expected)
97+
98+
result = a.shift(-1)
99+
expected = a.take([1, 2, -1], allow_fill=True)
100+
tm.assert_interval_array_equal(result, expected)
101+
84102

85103
class TestSetitem:
86104
def test_set_na(self, left_right_dtypes):

pandas/tests/extension/base/methods.py

+7
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,13 @@ def test_shift_empty_array(self, data, periods):
280280
expected = empty
281281
self.assert_extension_array_equal(result, expected)
282282

283+
def test_shift_zero_copies(self, data):
284+
result = data.shift(0)
285+
assert result is not data
286+
287+
result = data[:0].shift(2)
288+
assert result is not data
289+
283290
def test_shift_fill_value(self, data):
284291
arr = data[:4]
285292
fill_value = data[0]

0 commit comments

Comments
 (0)