diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index b5e323fbd0fa4..a1681322a4a5f 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -4234,6 +4234,9 @@ def putmask(self, mask, value): values = self.values.copy() try: np.putmask(values, mask, self._convert_for_op(value)) + if is_period_dtype(self.dtype): + # .values cast to object, so we need to cast back + values = type(self)(values)._data return self._shallow_copy(values) except (ValueError, TypeError) as err: if is_object_dtype(self): diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index 72a2aba2d8a88..f9ce8eb6d720d 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -520,7 +520,8 @@ def where(self, cond, other=None): other = other.view("i8") result = np.where(cond, values, other).astype("i8") - return self._shallow_copy(result) + arr = type(self._data)._simple_new(result, dtype=self.dtype) + return type(self)._simple_new(arr, name=self.name) def _summary(self, name=None) -> str: """ diff --git a/pandas/core/indexes/period.py b/pandas/core/indexes/period.py index 35a5d99abf4e6..017f104e18493 100644 --- a/pandas/core/indexes/period.py +++ b/pandas/core/indexes/period.py @@ -250,22 +250,11 @@ def _has_complex_internals(self): return True def _shallow_copy(self, values=None, name: Label = no_default): - # TODO: simplify, figure out type of values name = name if name is not no_default else self.name if values is None: values = self._data - if isinstance(values, type(self)): - values = values._data - - if not isinstance(values, PeriodArray): - if isinstance(values, np.ndarray) and values.dtype == "i8": - values = PeriodArray(values, freq=self.freq) - else: - # GH#30713 this should never be reached - raise TypeError(type(values), getattr(values, "dtype", None)) - return self._simple_new(values, name=name) def _maybe_convert_timedelta(self, other): @@ -618,10 +607,11 @@ def insert(self, loc, item): if not isinstance(item, Period) or self.freq != item.freq: return self.astype(object).insert(loc, item) - idx = np.concatenate( + i8result = np.concatenate( (self[:loc].asi8, np.array([item.ordinal]), self[loc:].asi8) ) - return self._shallow_copy(idx) + arr = type(self._data)._simple_new(i8result, dtype=self.dtype) + return type(self)._simple_new(arr, name=self.name) def join(self, other, how="left", level=None, return_indexers=False, sort=False): """ diff --git a/pandas/tests/base/test_ops.py b/pandas/tests/base/test_ops.py index f85d823cb2fac..298955654a89f 100644 --- a/pandas/tests/base/test_ops.py +++ b/pandas/tests/base/test_ops.py @@ -14,6 +14,7 @@ is_datetime64_dtype, is_datetime64tz_dtype, is_object_dtype, + is_period_dtype, needs_i8_conversion, ) @@ -295,6 +296,10 @@ def test_value_counts_unique_nunique_null(self, null_obj, index_or_series_obj): obj[0:2] = pd.NaT values = obj._values + elif is_period_dtype(obj): + values[0:2] = iNaT + parr = type(obj._data)(values, dtype=obj.dtype) + values = obj._shallow_copy(parr) elif needs_i8_conversion(obj): values[0:2] = iNaT values = obj._shallow_copy(values) diff --git a/pandas/tests/indexes/datetimes/test_indexing.py b/pandas/tests/indexes/datetimes/test_indexing.py index ceab670fb5041..554ae76979ba8 100644 --- a/pandas/tests/indexes/datetimes/test_indexing.py +++ b/pandas/tests/indexes/datetimes/test_indexing.py @@ -121,6 +121,14 @@ def test_dti_custom_getitem_matplotlib_hackaround(self): class TestWhere: + def test_where_doesnt_retain_freq(self): + dti = date_range("20130101", periods=3, freq="D", name="idx") + cond = [True, True, False] + expected = DatetimeIndex([dti[0], dti[1], dti[0]], freq=None, name="idx") + + result = dti.where(cond, dti[::-1]) + tm.assert_index_equal(result, expected) + def test_where_other(self): # other is ndarray or Index i = pd.date_range("20130101", periods=3, tz="US/Eastern") diff --git a/pandas/tests/indexes/period/test_period.py b/pandas/tests/indexes/period/test_period.py index 40c7ffba46450..ab3e967f12360 100644 --- a/pandas/tests/indexes/period/test_period.py +++ b/pandas/tests/indexes/period/test_period.py @@ -117,7 +117,6 @@ def test_make_time_series(self): assert isinstance(series, Series) def test_shallow_copy_empty(self): - # GH13067 idx = PeriodIndex([], freq="M") result = idx._shallow_copy() @@ -125,11 +124,16 @@ def test_shallow_copy_empty(self): tm.assert_index_equal(result, expected) - def test_shallow_copy_i8(self): + def test_shallow_copy_disallow_i8(self): # GH-24391 pi = period_range("2018-01-01", periods=3, freq="2D") - result = pi._shallow_copy(pi.asi8) - tm.assert_index_equal(result, pi) + with pytest.raises(AssertionError, match="ndarray"): + pi._shallow_copy(pi.asi8) + + def test_shallow_copy_requires_disallow_period_index(self): + pi = period_range("2018-01-01", periods=3, freq="2D") + with pytest.raises(AssertionError, match="PeriodIndex"): + pi._shallow_copy(pi) def test_view_asi8(self): idx = PeriodIndex([], freq="M") diff --git a/pandas/tests/indexes/test_common.py b/pandas/tests/indexes/test_common.py index b46e6514b4536..c6ba5c9d61e9e 100644 --- a/pandas/tests/indexes/test_common.py +++ b/pandas/tests/indexes/test_common.py @@ -10,7 +10,7 @@ from pandas._libs.tslibs import iNaT -from pandas.core.dtypes.common import needs_i8_conversion +from pandas.core.dtypes.common import is_period_dtype, needs_i8_conversion import pandas as pd from pandas import CategoricalIndex, MultiIndex, RangeIndex @@ -219,7 +219,10 @@ def test_get_unique_index(self, indices): if not indices._can_hold_na: pytest.skip("Skip na-check if index cannot hold na") - if needs_i8_conversion(indices): + if is_period_dtype(indices): + vals = indices[[0] * 5]._data + vals[0] = pd.NaT + elif needs_i8_conversion(indices): vals = indices.asi8[[0] * 5] vals[0] = iNaT else: diff --git a/pandas/tests/indexes/timedeltas/test_indexing.py b/pandas/tests/indexes/timedeltas/test_indexing.py index 14fff6f9c85b5..5dec799832291 100644 --- a/pandas/tests/indexes/timedeltas/test_indexing.py +++ b/pandas/tests/indexes/timedeltas/test_indexing.py @@ -66,6 +66,14 @@ def test_timestamp_invalid_key(self, key): class TestWhere: + def test_where_doesnt_retain_freq(self): + tdi = timedelta_range("1 day", periods=3, freq="D", name="idx") + cond = [True, True, False] + expected = TimedeltaIndex([tdi[0], tdi[1], tdi[0]], freq=None, name="idx") + + result = tdi.where(cond, tdi[::-1]) + tm.assert_index_equal(result, expected) + def test_where_invalid_dtypes(self): tdi = timedelta_range("1 day", periods=3, freq="D", name="idx")