Skip to content

Commit d012623

Browse files
authored
REF: simplify PeriodIndex._shallow_copy (#32280)
1 parent f2a1325 commit d012623

File tree

8 files changed

+42
-20
lines changed

8 files changed

+42
-20
lines changed

pandas/core/indexes/base.py

+3
Original file line numberDiff line numberDiff line change
@@ -4236,6 +4236,9 @@ def putmask(self, mask, value):
42364236
values = self.values.copy()
42374237
try:
42384238
np.putmask(values, mask, self._convert_for_op(value))
4239+
if is_period_dtype(self.dtype):
4240+
# .values cast to object, so we need to cast back
4241+
values = type(self)(values)._data
42394242
return self._shallow_copy(values)
42404243
except (ValueError, TypeError) as err:
42414244
if is_object_dtype(self):

pandas/core/indexes/datetimelike.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,8 @@ def where(self, cond, other=None):
520520
other = other.view("i8")
521521

522522
result = np.where(cond, values, other).astype("i8")
523-
return self._shallow_copy(result)
523+
arr = type(self._data)._simple_new(result, dtype=self.dtype)
524+
return type(self)._simple_new(arr, name=self.name)
524525

525526
def _summary(self, name=None) -> str:
526527
"""

pandas/core/indexes/period.py

+3-13
Original file line numberDiff line numberDiff line change
@@ -250,22 +250,11 @@ def _has_complex_internals(self):
250250
return True
251251

252252
def _shallow_copy(self, values=None, name: Label = no_default):
253-
# TODO: simplify, figure out type of values
254253
name = name if name is not no_default else self.name
255254

256255
if values is None:
257256
values = self._data
258257

259-
if isinstance(values, type(self)):
260-
values = values._data
261-
262-
if not isinstance(values, PeriodArray):
263-
if isinstance(values, np.ndarray) and values.dtype == "i8":
264-
values = PeriodArray(values, freq=self.freq)
265-
else:
266-
# GH#30713 this should never be reached
267-
raise TypeError(type(values), getattr(values, "dtype", None))
268-
269258
return self._simple_new(values, name=name)
270259

271260
def _maybe_convert_timedelta(self, other):
@@ -618,10 +607,11 @@ def insert(self, loc, item):
618607
if not isinstance(item, Period) or self.freq != item.freq:
619608
return self.astype(object).insert(loc, item)
620609

621-
idx = np.concatenate(
610+
i8result = np.concatenate(
622611
(self[:loc].asi8, np.array([item.ordinal]), self[loc:].asi8)
623612
)
624-
return self._shallow_copy(idx)
613+
arr = type(self._data)._simple_new(i8result, dtype=self.dtype)
614+
return type(self)._simple_new(arr, name=self.name)
625615

626616
def join(self, other, how="left", level=None, return_indexers=False, sort=False):
627617
"""

pandas/tests/base/test_ops.py

+5
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
is_datetime64_dtype,
1515
is_datetime64tz_dtype,
1616
is_object_dtype,
17+
is_period_dtype,
1718
needs_i8_conversion,
1819
)
1920

@@ -295,6 +296,10 @@ def test_value_counts_unique_nunique_null(self, null_obj, index_or_series_obj):
295296
obj[0:2] = pd.NaT
296297
values = obj._values
297298

299+
elif is_period_dtype(obj):
300+
values[0:2] = iNaT
301+
parr = type(obj._data)(values, dtype=obj.dtype)
302+
values = obj._shallow_copy(parr)
298303
elif needs_i8_conversion(obj):
299304
values[0:2] = iNaT
300305
values = obj._shallow_copy(values)

pandas/tests/indexes/datetimes/test_indexing.py

+8
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,14 @@ def test_dti_custom_getitem_matplotlib_hackaround(self):
121121

122122

123123
class TestWhere:
124+
def test_where_doesnt_retain_freq(self):
125+
dti = date_range("20130101", periods=3, freq="D", name="idx")
126+
cond = [True, True, False]
127+
expected = DatetimeIndex([dti[0], dti[1], dti[0]], freq=None, name="idx")
128+
129+
result = dti.where(cond, dti[::-1])
130+
tm.assert_index_equal(result, expected)
131+
124132
def test_where_other(self):
125133
# other is ndarray or Index
126134
i = pd.date_range("20130101", periods=3, tz="US/Eastern")

pandas/tests/indexes/period/test_period.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -117,19 +117,23 @@ def test_make_time_series(self):
117117
assert isinstance(series, Series)
118118

119119
def test_shallow_copy_empty(self):
120-
121120
# GH13067
122121
idx = PeriodIndex([], freq="M")
123122
result = idx._shallow_copy()
124123
expected = idx
125124

126125
tm.assert_index_equal(result, expected)
127126

128-
def test_shallow_copy_i8(self):
127+
def test_shallow_copy_disallow_i8(self):
129128
# GH-24391
130129
pi = period_range("2018-01-01", periods=3, freq="2D")
131-
result = pi._shallow_copy(pi.asi8)
132-
tm.assert_index_equal(result, pi)
130+
with pytest.raises(AssertionError, match="ndarray"):
131+
pi._shallow_copy(pi.asi8)
132+
133+
def test_shallow_copy_requires_disallow_period_index(self):
134+
pi = period_range("2018-01-01", periods=3, freq="2D")
135+
with pytest.raises(AssertionError, match="PeriodIndex"):
136+
pi._shallow_copy(pi)
133137

134138
def test_view_asi8(self):
135139
idx = PeriodIndex([], freq="M")

pandas/tests/indexes/test_common.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from pandas._libs.tslibs import iNaT
1212

13-
from pandas.core.dtypes.common import needs_i8_conversion
13+
from pandas.core.dtypes.common import is_period_dtype, needs_i8_conversion
1414

1515
import pandas as pd
1616
from pandas import CategoricalIndex, MultiIndex, RangeIndex
@@ -219,7 +219,10 @@ def test_get_unique_index(self, indices):
219219
if not indices._can_hold_na:
220220
pytest.skip("Skip na-check if index cannot hold na")
221221

222-
if needs_i8_conversion(indices):
222+
if is_period_dtype(indices):
223+
vals = indices[[0] * 5]._data
224+
vals[0] = pd.NaT
225+
elif needs_i8_conversion(indices):
223226
vals = indices.asi8[[0] * 5]
224227
vals[0] = iNaT
225228
else:

pandas/tests/indexes/timedeltas/test_indexing.py

+8
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,14 @@ def test_timestamp_invalid_key(self, key):
6666

6767

6868
class TestWhere:
69+
def test_where_doesnt_retain_freq(self):
70+
tdi = timedelta_range("1 day", periods=3, freq="D", name="idx")
71+
cond = [True, True, False]
72+
expected = TimedeltaIndex([tdi[0], tdi[1], tdi[0]], freq=None, name="idx")
73+
74+
result = tdi.where(cond, tdi[::-1])
75+
tm.assert_index_equal(result, expected)
76+
6977
def test_where_invalid_dtypes(self):
7078
tdi = timedelta_range("1 day", periods=3, freq="D", name="idx")
7179

0 commit comments

Comments
 (0)