Skip to content

Commit a53cf8d

Browse files
authored
BUG: DataFrame.stack losing EA dtypes with multi-columns and mixed dtypes (#51691)
* BUG: DataFrame.stack losing EA dtypes for mixed dtype dataframes
* add test
* whatsnew
* avoid copy
* update test
1 parent b3d18af commit a53cf8d

File tree

3 files changed

+43
-15
lines changed

3 files changed

+43
-15
lines changed

doc/source/whatsnew/v2.1.0.rst

+1
Original file line number | Diff line number | Diff line change
@@ -202,6 +202,7 @@ Groupby/resample/rolling
202202

203203
Reshaping
204204
^^^^^^^^^
205+
- Bug in :meth:`DataFrame.stack` losing extension dtypes when the columns are a :class:`MultiIndex` and the frame contains mixed dtypes (:issue:`45740`)
205206
- Bug in :meth:`DataFrame.transpose` inferring dtype for object column (:issue:`51546`)
206207
-
207208

pandas/core/reshape/reshape.py

+9-12
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,10 @@
1414
from pandas.util._decorators import cache_readonly
1515
from pandas.util._exceptions import find_stack_level
1616

17-
from pandas.core.dtypes.cast import maybe_promote
17+
from pandas.core.dtypes.cast import (
18+
find_common_type,
19+
maybe_promote,
20+
)
1821
from pandas.core.dtypes.common import (
1922
ensure_platform_int,
2023
is_1d_only_ea_dtype,
@@ -746,25 +749,19 @@ def _convert_level_number(level_num: int, columns: Index):
746749
chunk.columns = level_vals_nan.take(chunk.columns.codes[-1])
747750
value_slice = chunk.reindex(columns=level_vals_used).values
748751
else:
749-
if frame._is_homogeneous_type and is_extension_array_dtype(
750-
frame.dtypes.iloc[0]
751-
):
752+
subset = this.iloc[:, loc]
753+
dtype = find_common_type(subset.dtypes.tolist())
754+
if is_extension_array_dtype(dtype):
752755
# TODO(EA2D): won't need special case, can go through .values
753756
# paths below (might change to ._values)
754-
dtype = this[this.columns[loc]].dtypes.iloc[0]
755-
subset = this[this.columns[loc]]
756-
757757
value_slice = dtype.construct_array_type()._concat_same_type(
758-
[x._values for _, x in subset.items()]
758+
[x._values.astype(dtype, copy=False) for _, x in subset.items()]
759759
)
760760
N, K = subset.shape
761761
idx = np.arange(N * K).reshape(K, N).T.ravel()
762762
value_slice = value_slice.take(idx)
763-
764-
elif frame._is_mixed_type:
765-
value_slice = this[this.columns[loc]].values
766763
else:
767-
value_slice = this.values[:, loc]
764+
value_slice = subset.values
768765

769766
if value_slice.ndim > 1:
770767
# i.e. not extension

pandas/tests/frame/test_stack_unstack.py

+33-3
Original file line number | Diff line number | Diff line change
@@ -1146,6 +1146,27 @@ def test_stack_multi_columns_non_unique_index(self, index, columns):
11461146
expected_codes = np.asarray(new_index.codes)
11471147
tm.assert_numpy_array_equal(stacked_codes, expected_codes)
11481148

1149+
@pytest.mark.parametrize(
1150+
"vals1, vals2, dtype1, dtype2, expected_dtype",
1151+
[
1152+
([1, 2], [3.0, 4.0], "Int64", "Float64", "Float64"),
1153+
([1, 2], ["foo", "bar"], "Int64", "string", "object"),
1154+
],
1155+
)
1156+
def test_stack_multi_columns_mixed_extension_types(
1157+
self, vals1, vals2, dtype1, dtype2, expected_dtype
1158+
):
1159+
# GH45740
1160+
df = DataFrame(
1161+
{
1162+
("A", 1): Series(vals1, dtype=dtype1),
1163+
("A", 2): Series(vals2, dtype=dtype2),
1164+
}
1165+
)
1166+
result = df.stack()
1167+
expected = df.astype(object).stack().astype(expected_dtype)
1168+
tm.assert_frame_equal(result, expected)
1169+
11491170
@pytest.mark.parametrize("level", [0, 1])
11501171
def test_unstack_mixed_extension_types(self, level):
11511172
index = MultiIndex.from_tuples([("A", 0), ("A", 1), ("B", 1)], names=["a", "b"])
@@ -2181,9 +2202,18 @@ def test_stack_nullable_dtype(self):
21812202
df[df.columns[0]] = df[df.columns[0]].astype(pd.Float64Dtype())
21822203
result = df.stack("station")
21832204

2184-
# TODO(EA2D): we get object dtype because DataFrame.values can't
2185-
# be an EA
2186-
expected = df.astype(object).stack("station")
2205+
expected = DataFrame(
2206+
{
2207+
"r": pd.array(
2208+
[50.0, 10.0, 10.0, 9.0, 305.0, 111.0], dtype=pd.Float64Dtype()
2209+
),
2210+
"t_mean": pd.array(
2211+
[226, 215, 215, 220, 232, 220], dtype=pd.Int64Dtype()
2212+
),
2213+
},
2214+
index=MultiIndex.from_product([index, columns.levels[0]]),
2215+
)
2216+
expected.columns.name = "element"
21872217
tm.assert_frame_equal(result, expected)
21882218

21892219
def test_unstack_mixed_level_names(self):

0 commit comments

Comments
 (0)