diff --git a/doc/source/whatsnew/v1.0.1.rst b/doc/source/whatsnew/v1.0.1.rst index e0182e4e3c1f2..f41eca9d6b846 100644 --- a/doc/source/whatsnew/v1.0.1.rst +++ b/doc/source/whatsnew/v1.0.1.rst @@ -56,6 +56,9 @@ Bug fixes - Plotting tz-aware timeseries no longer gives UserWarning (:issue:`31205`) +**Interval** + +- Bug in :meth:`Series.shift` with ``interval`` dtype raising a ``TypeError`` when shifting an interval array of integers or datetimes (:issue:`34195`) .. --------------------------------------------------------------------------- diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index 23cf5f317ac7d..cc41ac1dc19a9 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -27,6 +27,7 @@ from pandas.core.dtypes.dtypes import IntervalDtype from pandas.core.dtypes.generic import ( ABCDatetimeIndex, + ABCExtensionArray, ABCIndexClass, ABCInterval, ABCIntervalIndex, @@ -789,6 +790,33 @@ def size(self) -> int: # Avoid materializing self.values return self.left.size + def shift(self, periods: int = 1, fill_value: object = None) -> ABCExtensionArray: + if not len(self) or periods == 0: + return self.copy() + + if isna(fill_value): + fill_value = self.dtype.na_value + + # ExtensionArray.shift doesn't work for two reasons + # 1. IntervalArray.dtype.na_value may not be correct for the dtype. + # 2. IntervalArray._from_sequence only accepts NaN for missing values, + # not other values like NaT + + empty_len = min(abs(periods), len(self)) + if isna(fill_value): + fill_value = self.left._na_value + empty = IntervalArray.from_breaks([fill_value] * (empty_len + 1)) + else: + empty = self._from_sequence([fill_value] * empty_len) + + if periods > 0: + a = empty + b = self[:-periods] + else: + a = self[abs(periods) :] + b = empty + return self._concat_same_type([a, b]) + def take(self, indices, allow_fill=False, fill_value=None, axis=None, **kwargs): """ Take elements from the IntervalArray. diff --git a/pandas/tests/arrays/interval/test_interval.py b/pandas/tests/arrays/interval/test_interval.py index e046d87780bb4..a43ea7e40a16a 100644 --- a/pandas/tests/arrays/interval/test_interval.py +++ b/pandas/tests/arrays/interval/test_interval.py @@ -81,6 +81,24 @@ def test_where_raises(self, other): with pytest.raises(ValueError, match=match): ser.where([True, False, True], other=other) + def test_shift(self): + # https://github.com/pandas-dev/pandas/issues/31495 + a = IntervalArray.from_breaks([1, 2, 3]) + result = a.shift() + # int -> float + expected = IntervalArray.from_tuples([(np.nan, np.nan), (1.0, 2.0)]) + tm.assert_interval_array_equal(result, expected) + + def test_shift_datetime(self): + a = IntervalArray.from_breaks(pd.date_range("2000", periods=4)) + result = a.shift(2) + expected = a.take([-1, -1, 0], allow_fill=True) + tm.assert_interval_array_equal(result, expected) + + result = a.shift(-1) + expected = a.take([1, 2, -1], allow_fill=True) + tm.assert_interval_array_equal(result, expected) + class TestSetitem: def test_set_na(self, left_right_dtypes): diff --git a/pandas/tests/extension/base/methods.py b/pandas/tests/extension/base/methods.py index 24ab7fe3fc845..6ed8b782deffa 100644 --- a/pandas/tests/extension/base/methods.py +++ b/pandas/tests/extension/base/methods.py @@ -280,6 +280,13 @@ def test_shift_empty_array(self, data, periods): expected = empty self.assert_extension_array_equal(result, expected) + def test_shift_zero_copies(self, data): + result = data.shift(0) + assert result is not data + + result = data[:0].shift(2) + assert result is not data + def test_shift_fill_value(self, data): arr = data[:4] fill_value = data[0]