Skip to content

Commit 00e0525

Browse files
TomAugspurgerPingviinituutti
authored andcommitted
Preserve EA dtype in DataFrame.stack (pandas-dev#23285)
1 parent 3c67ae3 commit 00e0525

File tree

7 files changed

+111
-6
lines changed

7 files changed

+111
-6
lines changed

doc/source/whatsnew/v0.24.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,7 @@ update the ``ExtensionDtype._metadata`` tuple to match the signature of your
853853
- Updated the ``.type`` attribute for ``PeriodDtype``, ``DatetimeTZDtype``, and ``IntervalDtype`` to be instances of the dtype (``Period``, ``Timestamp``, and ``Interval`` respectively) (:issue:`22938`)
854854
- :func:`ExtensionArray.isna` is allowed to return an ``ExtensionArray`` (:issue:`22325`).
855855
- Support for reduction operations such as ``sum``, ``mean`` via opt-in base class method override (:issue:`22762`)
856+
- :meth:`DataFrame.stack` no longer converts to object dtype for DataFrames where each column has the same extension dtype. The output Series will have the same dtype as the columns (:issue:`23077`).
856857
- :meth:`Series.unstack` and :meth:`DataFrame.unstack` no longer convert extension arrays to object-dtype ndarrays. Each column in the output ``DataFrame`` will now have the same dtype as the input (:issue:`23077`).
857858
- Bug when grouping :meth:`Dataframe.groupby()` and aggregating on ``ExtensionArray`` it was not returning the actual ``ExtensionArray`` dtype (:issue:`23227`).
858859

pandas/core/internals/blocks.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@
3535
is_numeric_v_string_like, is_extension_type,
3636
is_extension_array_dtype,
3737
is_list_like,
38-
is_sparse,
3938
is_re,
4039
is_re_compilable,
40+
is_sparse,
4141
pandas_dtype)
4242
from pandas.core.dtypes.cast import (
4343
maybe_downcast_to_dtype,

pandas/core/reshape/reshape.py

+57-5
Original file line numberDiff line numberDiff line change
@@ -494,8 +494,9 @@ def factorize(index):
494494
if is_extension_array_dtype(dtype):
495495
arr = dtype.construct_array_type()
496496
new_values = arr._concat_same_type([
497-
col for _, col in frame.iteritems()
497+
col._values for _, col in frame.iteritems()
498498
])
499+
new_values = _reorder_for_extension_array_stack(new_values, N, K)
499500
else:
500501
# homogeneous, non-EA
501502
new_values = frame.values.ravel()
@@ -624,16 +625,32 @@ def _convert_level_number(level_num, columns):
624625
slice_len = loc.stop - loc.start
625626

626627
if slice_len != levsize:
627-
chunk = this.loc[:, this.columns[loc]]
628+
chunk = this[this.columns[loc]]
628629
chunk.columns = level_vals.take(chunk.columns.labels[-1])
629630
value_slice = chunk.reindex(columns=level_vals_used).values
630631
else:
631-
if frame._is_mixed_type:
632-
value_slice = this.loc[:, this.columns[loc]].values
632+
if (frame._is_homogeneous_type and
633+
is_extension_array_dtype(frame.dtypes.iloc[0])):
634+
dtype = this[this.columns[loc]].dtypes.iloc[0]
635+
subset = this[this.columns[loc]]
636+
637+
value_slice = dtype.construct_array_type()._concat_same_type(
638+
[x._values for _, x in subset.iteritems()]
639+
)
640+
N, K = this.shape
641+
idx = np.arange(N * K).reshape(K, N).T.ravel()
642+
value_slice = value_slice.take(idx)
643+
644+
elif frame._is_mixed_type:
645+
value_slice = this[this.columns[loc]].values
633646
else:
634647
value_slice = this.values[:, loc]
635648

636-
new_data[key] = value_slice.ravel()
649+
if value_slice.ndim > 1:
650+
# i.e. not extension
651+
value_slice = value_slice.ravel()
652+
653+
new_data[key] = value_slice
637654

638655
if len(drop_cols) > 0:
639656
new_columns = new_columns.difference(drop_cols)
@@ -971,3 +988,38 @@ def make_axis_dummies(frame, axis='minor', transform=None):
971988
values = values.take(labels, axis=0)
972989

973990
return DataFrame(values, columns=items, index=frame.index)
991+
992+
993+
def _reorder_for_extension_array_stack(arr, n_rows, n_columns):
994+
"""
995+
Re-orders the values when stacking multiple extension-arrays.
996+
997+
The indirect stacking method used for EAs requires a followup
998+
take to get the order correct.
999+
1000+
Parameters
1001+
----------
1002+
arr : ExtensionArray
1003+
n_rows, n_columns : int
1004+
The number of rows and columns in the original DataFrame.
1005+
1006+
Returns
1007+
-------
1008+
taken : ExtensionArray
1009+
The original `arr` with elements re-ordered appropriately
1010+
1011+
Examples
1012+
--------
1013+
>>> arr = np.array(['a', 'b', 'c', 'd', 'e', 'f'])
1014+
>>> _reorder_for_extension_array_stack(arr, 2, 3)
1015+
array(['a', 'c', 'e', 'b', 'd', 'f'], dtype='<U1')
1016+
1017+
>>> _reorder_for_extension_array_stack(arr, 3, 2)
1018+
array(['a', 'd', 'b', 'e', 'c', 'f'], dtype='<U1')
1019+
"""
1020+
# final take to get the order correct.
1021+
# idx is an indexer like
1022+
# [c0r0, c1r0, c2r0, ...,
1023+
# c0r1, c1r1, c2r1, ...]
1024+
idx = np.arange(n_rows * n_columns).reshape(n_columns, n_rows).T.ravel()
1025+
return arr.take(idx)

pandas/tests/extension/base/reshaping.py

+22
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,28 @@ def test_merge(self, data, na_value):
173173
dtype=data.dtype)})
174174
self.assert_frame_equal(res, exp[['ext', 'int1', 'key', 'int2']])
175175

176+
@pytest.mark.parametrize("columns", [
177+
["A", "B"],
178+
pd.MultiIndex.from_tuples([('A', 'a'), ('A', 'b')],
179+
names=['outer', 'inner']),
180+
])
181+
def test_stack(self, data, columns):
182+
df = pd.DataFrame({"A": data[:5], "B": data[:5]})
183+
df.columns = columns
184+
result = df.stack()
185+
expected = df.astype(object).stack()
186+
# we need a second astype(object), in case the constructor inferred
187+
# object -> specialized, as is done for period.
188+
expected = expected.astype(object)
189+
190+
if isinstance(expected, pd.Series):
191+
assert result.dtype == df.iloc[:, 0].dtype
192+
else:
193+
assert all(result.dtypes == df.iloc[:, 0].dtype)
194+
195+
result = result.astype(object)
196+
self.assert_equal(result, expected)
197+
176198
@pytest.mark.parametrize("index", [
177199
# Two levels, uniform.
178200
pd.MultiIndex.from_product(([['A', 'B'], ['a', 'b']]),

pandas/tests/extension/json/test_json.py

+9
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,15 @@ def test_from_dtype(self, data):
139139

140140

141141
class TestReshaping(BaseJSON, base.BaseReshapingTests):
142+
143+
@pytest.mark.skip(reason="Different definitions of NA")
144+
def test_stack(self):
145+
"""
146+
The test does .astype(object).stack(). If we happen to have
147+
any missing values in `data`, then we'll end up with different
148+
rows since we consider `{}` NA, but `.astype(object)` doesn't.
149+
"""
150+
142151
@pytest.mark.xfail(reason="dict for NA", strict=True)
143152
def test_unstack(self, data, index):
144153
# The base test has NaN for the expected NA value.

pandas/tests/frame/test_reshape.py

+11
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,17 @@ def test_stack_preserve_categorical_dtype(self, ordered, labels):
874874

875875
tm.assert_series_equal(result, expected)
876876

877+
def test_stack_preserve_categorical_dtype_values(self):
878+
# GH-23077
879+
cat = pd.Categorical(['a', 'a', 'b', 'c'])
880+
df = pd.DataFrame({"A": cat, "B": cat})
881+
result = df.stack()
882+
index = pd.MultiIndex.from_product([[0, 1, 2, 3], ['A', 'B']])
883+
expected = pd.Series(pd.Categorical(['a', 'a', 'a', 'a',
884+
'b', 'b', 'c', 'c']),
885+
index=index)
886+
tm.assert_series_equal(result, expected)
887+
877888
@pytest.mark.parametrize('level', [0, 1])
878889
def test_unstack_mixed_extension_types(self, level):
879890
index = pd.MultiIndex.from_tuples([('A', 0), ('A', 1), ('B', 1)],

pandas/tests/sparse/frame/test_frame.py

+10
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,16 @@ def test_astype_bool(self):
736736
assert res['A'].dtype == SparseDtype(np.bool)
737737
assert res['B'].dtype == SparseDtype(np.bool)
738738

739+
def test_astype_object(self):
740+
# This may change in GH-23125
741+
df = pd.DataFrame({"A": SparseArray([0, 1]),
742+
"B": SparseArray([0, 1])})
743+
result = df.astype(object)
744+
dtype = SparseDtype(object, 0)
745+
expected = pd.DataFrame({"A": SparseArray([0, 1], dtype=dtype),
746+
"B": SparseArray([0, 1], dtype=dtype)})
747+
tm.assert_frame_equal(result, expected)
748+
739749
def test_fillna(self, float_frame_fill0, float_frame_fill0_dense):
740750
df = float_frame_fill0.reindex(lrange(5))
741751
dense = float_frame_fill0_dense.reindex(lrange(5))

0 commit comments

Comments
 (0)