Skip to content

Commit 7b5957f

Browse files
authored
BUG: Set dtypes of new columns when stacking (#36991) (#40127)
1 parent c17b84a commit 7b5957f

File tree

3 files changed

+51
-16
lines changed

3 files changed

+51
-16
lines changed

doc/source/whatsnew/v1.3.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ Reshaping
591591
- Bug in :meth:`DataFrame.append` returning incorrect dtypes with combinations of ``datetime64`` and ``timedelta64`` dtypes (:issue:`39574`)
592592
- Bug in :meth:`DataFrame.pivot_table` returning a ``MultiIndex`` for a single value when operating on and empty ``DataFrame`` (:issue:`13483`)
593593
- Allow :class:`Index` to be passed to the :func:`numpy.all` function (:issue:`40180`)
594-
-
594+
- Bug in :meth:`DataFrame.stack` not preserving ``CategoricalDtype`` in a ``MultiIndex`` (:issue:`36991`)
595595

596596
Sparse
597597
^^^^^^

pandas/core/reshape/reshape.py

+29-15
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,33 @@ def stack_multiple(frame, level, dropna=True):
600600
return result
601601

602602

603+
def _stack_multi_column_index(columns: MultiIndex) -> MultiIndex:
604+
"""Creates a MultiIndex from the first N-1 levels of this MultiIndex."""
605+
if len(columns.levels) <= 2:
606+
return columns.levels[0]._rename(name=columns.names[0])
607+
608+
levs = [
609+
[lev[c] if c >= 0 else None for c in codes]
610+
for lev, codes in zip(columns.levels[:-1], columns.codes[:-1])
611+
]
612+
613+
# Remove duplicate tuples in the MultiIndex.
614+
tuples = zip(*levs)
615+
unique_tuples = (key for key, _ in itertools.groupby(tuples))
616+
new_levs = zip(*unique_tuples)
617+
618+
# The dtype of each level must be explicitly set to avoid inferring the wrong type.
619+
# See GH-36991.
620+
return MultiIndex.from_arrays(
621+
[
622+
# Not all indices can accept None values.
623+
Index(new_lev, dtype=lev.dtype) if None not in new_lev else new_lev
624+
for new_lev, lev in zip(new_levs, columns.levels)
625+
],
626+
names=columns.names[:-1],
627+
)
628+
629+
603630
def _stack_multi_columns(frame, level_num=-1, dropna=True):
604631
def _convert_level_number(level_num, columns):
605632
"""
@@ -634,20 +661,7 @@ def _convert_level_number(level_num, columns):
634661
level_to_sort = _convert_level_number(0, this.columns)
635662
this = this.sort_index(level=level_to_sort, axis=1)
636663

637-
# tuple list excluding level for grouping columns
638-
if len(frame.columns.levels) > 2:
639-
levs = []
640-
for lev, level_codes in zip(this.columns.levels[:-1], this.columns.codes[:-1]):
641-
if -1 in level_codes:
642-
lev = np.append(lev, None)
643-
levs.append(np.take(lev, level_codes))
644-
tuples = list(zip(*levs))
645-
unique_groups = [key for key, _ in itertools.groupby(tuples)]
646-
new_names = this.columns.names[:-1]
647-
new_columns = MultiIndex.from_tuples(unique_groups, names=new_names)
648-
else:
649-
new_columns = this.columns.levels[0]._rename(name=this.columns.names[0])
650-
unique_groups = new_columns
664+
new_columns = _stack_multi_column_index(this.columns)
651665

652666
# time to ravel the values
653667
new_data = {}
@@ -658,7 +672,7 @@ def _convert_level_number(level_num, columns):
658672
level_vals_used = np.take(level_vals_nan, level_codes)
659673
levsize = len(level_codes)
660674
drop_cols = []
661-
for key in unique_groups:
675+
for key in new_columns:
662676
try:
663677
loc = this.columns.get_loc(key)
664678
except KeyError:

pandas/tests/frame/test_stack_unstack.py

+21
Original file line numberDiff line numberDiff line change
@@ -1065,6 +1065,27 @@ def test_stack_preserve_categorical_dtype(self, ordered, labels):
10651065

10661066
tm.assert_series_equal(result, expected)
10671067

1068+
@pytest.mark.parametrize("ordered", [False, True])
1069+
@pytest.mark.parametrize(
1070+
"labels,data",
1071+
[
1072+
(list("xyz"), [10, 11, 12, 13, 14, 15]),
1073+
(list("zyx"), [14, 15, 12, 13, 10, 11]),
1074+
],
1075+
)
1076+
def test_stack_multi_preserve_categorical_dtype(self, ordered, labels, data):
1077+
# GH-36991
1078+
cidx = pd.CategoricalIndex(labels, categories=sorted(labels), ordered=ordered)
1079+
cidx2 = pd.CategoricalIndex(["u", "v"], ordered=ordered)
1080+
midx = MultiIndex.from_product([cidx, cidx2])
1081+
df = DataFrame([sorted(data)], columns=midx)
1082+
result = df.stack([0, 1])
1083+
1084+
s_cidx = pd.CategoricalIndex(sorted(labels), ordered=ordered)
1085+
expected = Series(data, index=MultiIndex.from_product([[0], s_cidx, cidx2]))
1086+
1087+
tm.assert_series_equal(result, expected)
1088+
10681089
def test_stack_preserve_categorical_dtype_values(self):
10691090
# GH-23077
10701091
cat = pd.Categorical(["a", "a", "b", "c"])

0 commit comments

Comments
 (0)