Skip to content

Commit a135e4a

Browse files
jbrockmendel authored and simonjayhawkins committed
Backport PR pandas-dev#38120: API: preserve freq in DTI/TDI.factorize
1 parent 8a2b8e2 commit a135e4a

File tree

8 files changed

+100
-20
lines changed

8 files changed

+100
-20
lines changed

doc/source/whatsnew/v1.1.5.rst

+1
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@ Fixed regressions
1919
- Fixed regression in :meth:`DataFrame.loc` and :meth:`Series.loc` for ``__setitem__`` when one-dimensional tuple was given to select from :class:`MultiIndex` (:issue:`37711`)
2020
- Fixed regression in inplace operations on :class:`Series` with ``ExtensionDtype`` with NumPy dtyped operand (:issue:`37910`)
2121
- Fixed regression in metadata propagation for ``groupby`` iterator (:issue:`37343`)
22+
- Fixed regression in :class:`MultiIndex` constructed from a :class:`DatetimeIndex` not retaining frequency (:issue:`35563`)
2223
- Fixed regression in indexing on a :class:`Series` with ``CategoricalDtype`` after unpickling (:issue:`37631`)
2324
- Fixed regression in :meth:`DataFrame.groupby` aggregation with out-of-bounds datetime objects in an object-dtype column (:issue:`36003`)
2425
- Fixed regression in ``df.groupby(..).rolling(..)`` with the resulting :class:`MultiIndex` when grouping by a label that is in the index (:issue:`37641`)

pandas/core/algorithms.py

+26-2
Original file line number | Diff line number | Diff line change
@@ -46,11 +46,13 @@
4646
pandas_dtype,
4747
)
4848
from pandas.core.dtypes.generic import (
49+
ABCDatetimeArray,
4950
ABCExtensionArray,
5051
ABCIndex,
5152
ABCIndexClass,
5253
ABCMultiIndex,
5354
ABCSeries,
55+
ABCTimedeltaArray,
5456
)
5557
from pandas.core.dtypes.missing import isna, na_value_for_dtype
5658

@@ -191,8 +193,16 @@ def _reconstruct_data(
191193
-------
192194
ExtensionArray or np.ndarray
193195
"""
196+
if isinstance(values, ABCExtensionArray) and values.dtype == dtype:
197+
# Catch DatetimeArray/TimedeltaArray
198+
return values
199+
194200
if is_extension_array_dtype(dtype):
195-
values = dtype.construct_array_type()._from_sequence(values)
201+
cls = dtype.construct_array_type()
202+
if isinstance(values, cls) and values.dtype == dtype:
203+
return values
204+
205+
values = cls._from_sequence(values)
196206
elif is_bool_dtype(dtype):
197207
values = values.astype(dtype, copy=False)
198208

@@ -654,6 +664,8 @@ def factorize(
654664

655665
values = _ensure_arraylike(values)
656666
original = values
667+
if not isinstance(values, ABCMultiIndex):
668+
values = extract_array(values, extract_numpy=True)
657669

658670
# GH35667, if na_sentinel=None, we will not dropna NaNs from the uniques
659671
# of values, assign na_sentinel=-1 to replace code value for NaN.
@@ -662,8 +674,20 @@ def factorize(
662674
na_sentinel = -1
663675
dropna = False
664676

677+
if (
678+
isinstance(values, (ABCDatetimeArray, ABCTimedeltaArray))
679+
and values.freq is not None
680+
):
681+
codes, uniques = values.factorize(sort=sort)
682+
if isinstance(original, ABCIndexClass):
683+
uniques = original._shallow_copy(uniques, name=None)
684+
elif isinstance(original, ABCSeries):
685+
from pandas import Index
686+
687+
uniques = Index(uniques)
688+
return codes, uniques
689+
665690
if is_extension_array_dtype(values.dtype):
666-
values = extract_array(values)
667691
codes, uniques = values.factorize(na_sentinel=na_sentinel)
668692
dtype = original.dtype
669693
else:

pandas/core/arrays/datetimelike.py

+14
Original file line number | Diff line number | Diff line change
@@ -1660,6 +1660,20 @@ def mean(self, skipna=True):
16601660
# Don't have to worry about NA `result`, since no NA went in.
16611661
return self._box_func(result)
16621662

1663+
# --------------------------------------------------------------
1664+
1665+
def factorize(self, na_sentinel=-1, sort: bool = False):
1666+
if self.freq is not None:
1667+
# We must be unique, so can short-circuit (and retain freq)
1668+
codes = np.arange(len(self), dtype=np.intp)
1669+
uniques = self.copy() # TODO: copy or view?
1670+
if sort and self.freq.n < 0:
1671+
codes = codes[::-1]
1672+
uniques = uniques[::-1]
1673+
return codes, uniques
1674+
# FIXME: shouldn't get here; we are ignoring sort
1675+
return super().factorize(na_sentinel=na_sentinel)
1676+
16631677

16641678
DatetimeLikeArrayMixin._add_comparison_ops()
16651679

pandas/core/arrays/period.py

+4
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,7 @@
4848

4949
import pandas.core.algorithms as algos
5050
from pandas.core.arrays import datetimelike as dtl
51+
from pandas.core.arrays.base import ExtensionArray
5152
import pandas.core.common as com
5253

5354

@@ -766,6 +767,9 @@ def _check_timedeltalike_freq_compat(self, other):
766767

767768
raise raise_on_incompatible(self, other)
768769

770+
def factorize(self, na_sentinel=-1):
771+
return ExtensionArray.factorize(self, na_sentinel=na_sentinel)
772+
769773

770774
def raise_on_incompatible(left, right):
771775
"""

pandas/tests/indexes/datetimes/test_datetime.py

+35-16
Original file line number | Diff line number | Diff line change
@@ -271,10 +271,12 @@ def test_factorize(self):
271271
arr, idx = idx1.factorize()
272272
tm.assert_numpy_array_equal(arr, exp_arr)
273273
tm.assert_index_equal(idx, exp_idx)
274+
assert idx.freq == exp_idx.freq
274275

275276
arr, idx = idx1.factorize(sort=True)
276277
tm.assert_numpy_array_equal(arr, exp_arr)
277278
tm.assert_index_equal(idx, exp_idx)
279+
assert idx.freq == exp_idx.freq
278280

279281
# tz must be preserved
280282
idx1 = idx1.tz_localize("Asia/Tokyo")
@@ -283,6 +285,7 @@ def test_factorize(self):
283285
arr, idx = idx1.factorize()
284286
tm.assert_numpy_array_equal(arr, exp_arr)
285287
tm.assert_index_equal(idx, exp_idx)
288+
assert idx.freq == exp_idx.freq
286289

287290
idx2 = pd.DatetimeIndex(
288291
["2014-03", "2014-03", "2014-02", "2014-01", "2014-03", "2014-01"]
@@ -293,49 +296,65 @@ def test_factorize(self):
293296
arr, idx = idx2.factorize(sort=True)
294297
tm.assert_numpy_array_equal(arr, exp_arr)
295298
tm.assert_index_equal(idx, exp_idx)
299+
assert idx.freq == exp_idx.freq
296300

297301
exp_arr = np.array([0, 0, 1, 2, 0, 2], dtype=np.intp)
298302
exp_idx = DatetimeIndex(["2014-03", "2014-02", "2014-01"])
299303
arr, idx = idx2.factorize()
300304
tm.assert_numpy_array_equal(arr, exp_arr)
301305
tm.assert_index_equal(idx, exp_idx)
306+
assert idx.freq == exp_idx.freq
302307

303-
# freq must be preserved
308+
def test_factorize_preserves_freq(self):
309+
# GH#38120 freq should be preserved
304310
idx3 = date_range("2000-01", periods=4, freq="M", tz="Asia/Tokyo")
305311
exp_arr = np.array([0, 1, 2, 3], dtype=np.intp)
312+
306313
arr, idx = idx3.factorize()
307314
tm.assert_numpy_array_equal(arr, exp_arr)
308315
tm.assert_index_equal(idx, idx3)
316+
assert idx.freq == idx3.freq
317+
318+
arr, idx = pd.factorize(idx3)
319+
tm.assert_numpy_array_equal(arr, exp_arr)
320+
tm.assert_index_equal(idx, idx3)
321+
assert idx.freq == idx3.freq
309322

310-
def test_factorize_tz(self, tz_naive_fixture):
323+
def test_factorize_tz(self, tz_naive_fixture, index_or_series):
311324
tz = tz_naive_fixture
312325
# GH#13750
313326
base = pd.date_range("2016-11-05", freq="H", periods=100, tz=tz)
314327
idx = base.repeat(5)
315328

316329
exp_arr = np.arange(100, dtype=np.intp).repeat(5)
317330

318-
for obj in [idx, pd.Series(idx)]:
319-
arr, res = obj.factorize()
320-
tm.assert_numpy_array_equal(arr, exp_arr)
321-
expected = base._with_freq(None)
322-
tm.assert_index_equal(res, expected)
331+
obj = index_or_series(idx)
332+
333+
arr, res = obj.factorize()
334+
tm.assert_numpy_array_equal(arr, exp_arr)
335+
expected = base._with_freq(None)
336+
tm.assert_index_equal(res, expected)
337+
assert res.freq == expected.freq
323338

324-
def test_factorize_dst(self):
339+
def test_factorize_dst(self, index_or_series):
325340
# GH 13750
326341
idx = pd.date_range("2016-11-06", freq="H", periods=12, tz="US/Eastern")
342+
obj = index_or_series(idx)
327343

328-
for obj in [idx, pd.Series(idx)]:
329-
arr, res = obj.factorize()
330-
tm.assert_numpy_array_equal(arr, np.arange(12, dtype=np.intp))
331-
tm.assert_index_equal(res, idx)
344+
arr, res = obj.factorize()
345+
tm.assert_numpy_array_equal(arr, np.arange(12, dtype=np.intp))
346+
tm.assert_index_equal(res, idx)
347+
if index_or_series is Index:
348+
assert res.freq == idx.freq
332349

333350
idx = pd.date_range("2016-06-13", freq="H", periods=12, tz="US/Eastern")
351+
obj = index_or_series(idx)
334352

335-
for obj in [idx, pd.Series(idx)]:
336-
arr, res = obj.factorize()
337-
tm.assert_numpy_array_equal(arr, np.arange(12, dtype=np.intp))
338-
tm.assert_index_equal(res, idx)
353+
arr, res = obj.factorize()
354+
tm.assert_numpy_array_equal(arr, np.arange(12, dtype=np.intp))
355+
tm.assert_index_equal(res, idx)
356+
if index_or_series is Index:
357+
assert res.freq == idx.freq
339358

340359
@pytest.mark.parametrize(
341360
"arr, expected",

pandas/tests/indexes/timedeltas/test_timedelta.py

+10-1
Original file line number | Diff line number | Diff line change
@@ -75,17 +75,26 @@ def test_factorize(self):
7575
arr, idx = idx1.factorize()
7676
tm.assert_numpy_array_equal(arr, exp_arr)
7777
tm.assert_index_equal(idx, exp_idx)
78+
assert idx.freq == exp_idx.freq
7879

7980
arr, idx = idx1.factorize(sort=True)
8081
tm.assert_numpy_array_equal(arr, exp_arr)
8182
tm.assert_index_equal(idx, exp_idx)
83+
assert idx.freq == exp_idx.freq
8284

83-
# freq must be preserved
85+
def test_factorize_preserves_freq(self):
86+
# GH#38120 freq should be preserved
8487
idx3 = timedelta_range("1 day", periods=4, freq="s")
8588
exp_arr = np.array([0, 1, 2, 3], dtype=np.intp)
8689
arr, idx = idx3.factorize()
8790
tm.assert_numpy_array_equal(arr, exp_arr)
8891
tm.assert_index_equal(idx, idx3)
92+
assert idx.freq == idx3.freq
93+
94+
arr, idx = pd.factorize(idx3)
95+
tm.assert_numpy_array_equal(arr, exp_arr)
96+
tm.assert_index_equal(idx, idx3)
97+
assert idx.freq == idx3.freq
8998

9099
def test_sort_values(self):
91100

pandas/tests/indexing/multiindex/test_multiindex.py

+10
Original file line number | Diff line number | Diff line change
@@ -91,3 +91,13 @@ def test_multiindex_get_loc_list_raises(self):
9191
msg = "unhashable type"
9292
with pytest.raises(TypeError, match=msg):
9393
idx.get_loc([])
94+
95+
def test_multiindex_with_datatime_level_preserves_freq(self):
96+
# https://github.com/pandas-dev/pandas/issues/35563
97+
idx = Index(range(2), name="A")
98+
dti = pd.date_range("2020-01-01", periods=7, freq="D", name="B")
99+
mi = MultiIndex.from_product([idx, dti])
100+
df = DataFrame(np.random.randn(14, 2), index=mi)
101+
result = df.loc[0].index
102+
tm.assert_index_equal(result, dti)
103+
assert result.freq == dti.freq

pandas/tests/window/common.py

-1
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@ def get_result(obj, obj2=None):
1212
result = result.loc[(slice(None), 1), 5]
1313
result.index = result.index.droplevel(1)
1414
expected = get_result(frame[1], frame[5])
15-
expected.index = expected.index._with_freq(None)
1615
tm.assert_series_equal(result, expected, check_names=False)
1716

1817

0 commit comments

Comments
 (0)