Skip to content

Commit cbf80c2

Browse files
authored
BUG: melt losing ea dtype (#50316)
* BUG: melt losing ea dtype * Add asv * Fix mypy
1 parent 5f3c29e commit cbf80c2

File tree

4 files changed

+39
-15
lines changed

4 files changed

+39
-15
lines changed

asv_bench/benchmarks/reshape.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,17 @@
1515

1616

1717
class Melt:
18-
def setup(self):
19-
self.df = DataFrame(np.random.randn(10000, 3), columns=["A", "B", "C"])
20-
self.df["id1"] = np.random.randint(0, 10, 10000)
21-
self.df["id2"] = np.random.randint(100, 1000, 10000)
18+
params = ["float64", "Float64"]
19+
param_names = ["dtype"]
20+
21+
def setup(self, dtype):
22+
self.df = DataFrame(
23+
np.random.randn(100_000, 3), columns=["A", "B", "C"], dtype=dtype
24+
)
25+
self.df["id1"] = pd.Series(np.random.randint(0, 10, 10000))
26+
self.df["id2"] = pd.Series(np.random.randint(100, 1000, 10000))
2227

23-
def time_melt_dataframe(self):
28+
def time_melt_dataframe(self, dtype):
2429
melt(self.df, id_vars=["id1", "id2"])
2530

2631

doc/source/whatsnew/v2.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -908,6 +908,7 @@ Reshaping
908908
^^^^^^^^^
909909
- Bug in :meth:`DataFrame.pivot_table` raising ``TypeError`` for nullable dtype and ``margins=True`` (:issue:`48681`)
910910
- Bug in :meth:`DataFrame.unstack` and :meth:`Series.unstack` unstacking wrong level of :class:`MultiIndex` when :class:`MultiIndex` has mixed names (:issue:`48763`)
911+
- Bug in :meth:`DataFrame.melt` losing extension array dtype (:issue:`41570`)
911912
- Bug in :meth:`DataFrame.pivot` not respecting ``None`` as column name (:issue:`48293`)
912913
- Bug in :func:`join` when ``left_on`` or ``right_on`` is or includes a :class:`CategoricalIndex` incorrectly raising ``AttributeError`` (:issue:`48464`)
913914
- Bug in :meth:`DataFrame.pivot_table` raising ``ValueError`` with parameter ``margins=True`` when result is an empty :class:`DataFrame` (:issue:`49240`)

pandas/core/reshape/melt.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
from pandas.core.tools.numeric import to_numeric
3434

3535
if TYPE_CHECKING:
36+
from pandas._typing import AnyArrayLike
37+
3638
from pandas import DataFrame
3739

3840

@@ -124,7 +126,7 @@ def melt(
124126
N, K = frame.shape
125127
K -= len(id_vars)
126128

127-
mdata = {}
129+
mdata: dict[Hashable, AnyArrayLike] = {}
128130
for col in id_vars:
129131
id_data = frame.pop(col)
130132
if is_extension_array_dtype(id_data):
@@ -141,17 +143,15 @@ def melt(
141143

142144
mcolumns = id_vars + var_name + [value_name]
143145

144-
# error: Incompatible types in assignment (expression has type "ndarray",
145-
# target has type "Series")
146-
mdata[value_name] = frame._values.ravel("F") # type: ignore[assignment]
146+
if frame.shape[1] > 0:
147+
mdata[value_name] = concat(
148+
[frame.iloc[:, i] for i in range(frame.shape[1])]
149+
).values
150+
else:
151+
mdata[value_name] = frame._values.ravel("F")
147152
for i, col in enumerate(var_name):
148153
# asanyarray will keep the columns as an Index
149-
150-
# error: Incompatible types in assignment (expression has type "ndarray", target
151-
# has type "Series")
152-
mdata[col] = np.asanyarray( # type: ignore[assignment]
153-
frame.columns._get_level_values(i)
154-
).repeat(N)
154+
mdata[col] = np.asanyarray(frame.columns._get_level_values(i)).repeat(N)
155155

156156
result = frame._constructor(mdata, columns=mcolumns)
157157

pandas/tests/reshape/test_melt.py

+18
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,24 @@ def test_melt_with_duplicate_columns(self):
420420
)
421421
tm.assert_frame_equal(result, expected)
422422

423+
@pytest.mark.parametrize("dtype", ["Int8", "Int64"])
424+
def test_melt_ea_dtype(self, dtype):
425+
# GH#41570
426+
df = DataFrame(
427+
{
428+
"a": pd.Series([1, 2], dtype="Int8"),
429+
"b": pd.Series([3, 4], dtype=dtype),
430+
}
431+
)
432+
result = df.melt()
433+
expected = DataFrame(
434+
{
435+
"variable": ["a", "a", "b", "b"],
436+
"value": pd.Series([1, 2, 3, 4], dtype=dtype),
437+
}
438+
)
439+
tm.assert_frame_equal(result, expected)
440+
423441

424442
class TestLreshape:
425443
def test_pairs(self):

0 commit comments

Comments
 (0)