Skip to content

BUG: DataFrame.stack losing EA dtypes with multi-columns and mixed dtypes #51691

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 9, 2023
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ Groupby/resample/rolling

Reshaping
^^^^^^^^^
- Bug in :meth:`DataFrame.stack` losing extension dtypes when columns is a :class:`MultiIndex` and frame contains mixed dtypes (:issue:`45740`)
- Bug in :meth:`DataFrame.transpose` inferring dtype for object column (:issue:`51546`)
-

Expand Down
21 changes: 9 additions & 12 deletions pandas/core/reshape/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
from pandas.util._decorators import cache_readonly
from pandas.util._exceptions import find_stack_level

from pandas.core.dtypes.cast import maybe_promote
from pandas.core.dtypes.cast import (
find_common_type,
maybe_promote,
)
from pandas.core.dtypes.common import (
ensure_platform_int,
is_1d_only_ea_dtype,
Expand Down Expand Up @@ -746,25 +749,19 @@ def _convert_level_number(level_num: int, columns: Index):
chunk.columns = level_vals_nan.take(chunk.columns.codes[-1])
value_slice = chunk.reindex(columns=level_vals_used).values
else:
if frame._is_homogeneous_type and is_extension_array_dtype(
frame.dtypes.iloc[0]
):
subset = this.iloc[:, loc]
dtype = find_common_type(subset.dtypes.tolist())
if is_extension_array_dtype(dtype):
# TODO(EA2D): won't need special case, can go through .values
# paths below (might change to ._values)
dtype = this[this.columns[loc]].dtypes.iloc[0]
subset = this[this.columns[loc]]

value_slice = dtype.construct_array_type()._concat_same_type(
[x._values for _, x in subset.items()]
[x._values.astype(dtype, copy=False) for _, x in subset.items()]
)
N, K = subset.shape
idx = np.arange(N * K).reshape(K, N).T.ravel()
value_slice = value_slice.take(idx)

elif frame._is_mixed_type:
value_slice = this[this.columns[loc]].values
else:
value_slice = this.values[:, loc]
value_slice = subset.values

if value_slice.ndim > 1:
# i.e. not extension
Expand Down
36 changes: 33 additions & 3 deletions pandas/tests/frame/test_stack_unstack.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,27 @@ def test_stack_multi_columns_non_unique_index(self, index, columns):
expected_codes = np.asarray(new_index.codes)
tm.assert_numpy_array_equal(stacked_codes, expected_codes)

@pytest.mark.parametrize(
"vals1, vals2, dtype1, dtype2, expected_dtype",
[
([1, 2], [3.0, 4.0], "Int64", "Float64", "Float64"),
([1, 2], ["foo", "bar"], "Int64", "string", "object"),
],
)
def test_stack_multi_columns_mixed_extension_types(
self, vals1, vals2, dtype1, dtype2, expected_dtype
):
# GH45740
df = DataFrame(
{
("A", 1): Series(vals1, dtype=dtype1),
("A", 2): Series(vals2, dtype=dtype2),
}
)
result = df.stack()
expected = df.astype(object).stack().astype(expected_dtype)
tm.assert_frame_equal(result, expected)

@pytest.mark.parametrize("level", [0, 1])
def test_unstack_mixed_extension_types(self, level):
index = MultiIndex.from_tuples([("A", 0), ("A", 1), ("B", 1)], names=["a", "b"])
Expand Down Expand Up @@ -2181,9 +2202,18 @@ def test_stack_nullable_dtype(self):
df[df.columns[0]] = df[df.columns[0]].astype(pd.Float64Dtype())
result = df.stack("station")

# TODO(EA2D): we get object dtype because DataFrame.values can't
# be an EA
expected = df.astype(object).stack("station")
expected = DataFrame(
{
"r": pd.array(
[50.0, 10.0, 10.0, 9.0, 305.0, 111.0], dtype=pd.Float64Dtype()
),
"t_mean": pd.array(
[226, 215, 215, 220, 232, 220], dtype=pd.Int64Dtype()
),
},
index=MultiIndex.from_product([index, columns.levels[0]]),
)
expected.columns.name = "element"
tm.assert_frame_equal(result, expected)

def test_unstack_mixed_level_names(self):
Expand Down