Skip to content

Commit 5054b34

Browse files
TomAugspurger authored and AlexKirko committed
BUG: preserve EA dtype in transpose (pandas-dev#30091)
1 parent bdfdcfa commit 5054b34

File tree

8 files changed

+68
-84
lines changed

8 files changed

+68
-84
lines changed

doc/source/whatsnew/v1.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -864,6 +864,7 @@ Reshaping
864864
- Bug where :meth:`DataFrame.equals` returned True incorrectly in some cases when two DataFrames had the same columns in different orders (:issue:`28839`)
865865
- Bug in :meth:`DataFrame.replace` that caused non-numeric replacer's dtype not respected (:issue:`26632`)
866866
- Bug in :func:`melt` where supplying mixed strings and numeric values for ``id_vars`` or ``value_vars`` would incorrectly raise a ``ValueError`` (:issue:`29718`)
867+
- Dtypes are now preserved when transposing a ``DataFrame`` where each column is the same extension dtype (:issue:`30091`)
867868
- Bug in :func:`merge_asof` merging on a tz-aware ``left_index`` and ``right_on`` a tz-aware column (:issue:`29864`)
868869
-
869870

pandas/core/frame.py

+32-5
Original file line numberDiff line numberDiff line change
@@ -2485,7 +2485,7 @@ def memory_usage(self, index=True, deep=False):
24852485
)
24862486
return result
24872487

2488-
def transpose(self, *args, **kwargs):
2488+
def transpose(self, *args, copy: bool = False):
24892489
"""
24902490
Transpose index and columns.
24912491
@@ -2495,9 +2495,14 @@ def transpose(self, *args, **kwargs):
24952495
24962496
Parameters
24972497
----------
2498-
*args, **kwargs
2499-
Additional arguments and keywords have no effect but might be
2500-
accepted for compatibility with numpy.
2498+
*args : tuple, optional
2499+
Accepted for compatibility with NumPy.
2500+
copy : bool, default False
2501+
Whether to copy the data after transposing, even for DataFrames
2502+
with a single dtype.
2503+
2504+
Note that a copy is always required for mixed dtype DataFrames,
2505+
or for DataFrames with any extension types.
25012506
25022507
Returns
25032508
-------
@@ -2578,7 +2583,29 @@ def transpose(self, *args, **kwargs):
25782583
dtype: object
25792584
"""
25802585
nv.validate_transpose(args, dict())
2581-
return super().transpose(1, 0, **kwargs)
2586+
# construct the args
2587+
2588+
dtypes = list(self.dtypes)
2589+
if self._is_homogeneous_type and dtypes and is_extension_array_dtype(dtypes[0]):
2590+
# We have EAs with the same dtype. We can preserve that dtype in transpose.
2591+
dtype = dtypes[0]
2592+
arr_type = dtype.construct_array_type()
2593+
values = self.values
2594+
2595+
new_values = [arr_type._from_sequence(row, dtype=dtype) for row in values]
2596+
result = self._constructor(
2597+
dict(zip(self.index, new_values)), index=self.columns
2598+
)
2599+
2600+
else:
2601+
new_values = self.values.T
2602+
if copy:
2603+
new_values = new_values.copy()
2604+
result = self._constructor(
2605+
new_values, index=self.columns, columns=self.index
2606+
)
2607+
2608+
return result.__finalize__(self)
25822609

25832610
T = property(transpose)
25842611

pandas/core/generic.py

-44
Original file line numberDiff line numberDiff line change
@@ -643,50 +643,6 @@ def _set_axis(self, axis, labels):
643643
self._data.set_axis(axis, labels)
644644
self._clear_item_cache()
645645

646-
def transpose(self, *args, **kwargs):
647-
"""
648-
Permute the dimensions of the %(klass)s
649-
650-
Parameters
651-
----------
652-
args : %(args_transpose)s
653-
copy : bool, default False
654-
Make a copy of the underlying data. Mixed-dtype data will
655-
always result in a copy
656-
**kwargs
657-
Additional keyword arguments will be passed to the function.
658-
659-
Returns
660-
-------
661-
y : same as input
662-
663-
Examples
664-
--------
665-
>>> p.transpose(2, 0, 1)
666-
>>> p.transpose(2, 0, 1, copy=True)
667-
"""
668-
669-
# construct the args
670-
axes, kwargs = self._construct_axes_from_arguments(
671-
args, kwargs, require_all=True
672-
)
673-
axes_names = tuple(self._get_axis_name(axes[a]) for a in self._AXIS_ORDERS)
674-
axes_numbers = tuple(self._get_axis_number(axes[a]) for a in self._AXIS_ORDERS)
675-
676-
# we must have unique axes
677-
if len(axes) != len(set(axes)):
678-
raise ValueError(f"Must specify {self._AXIS_LEN} unique axes")
679-
680-
new_axes = self._construct_axes_dict_from(
681-
self, [self._get_axis(x) for x in axes_names]
682-
)
683-
new_values = self.values.transpose(axes_numbers)
684-
if kwargs.pop("copy", None) or (len(args) and args[-1]):
685-
new_values = new_values.copy()
686-
687-
nv.validate_transpose(tuple(), kwargs)
688-
return self._constructor(new_values, **new_axes).__finalize__(self)
689-
690646
def swapaxes(self, axis1, axis2, copy=True):
691647
"""
692648
Interchange axes and swap values axes appropriately.

pandas/tests/arithmetic/conftest.py

-19
Original file line numberDiff line numberDiff line change
@@ -235,25 +235,6 @@ def box_df_fail(request):
235235
return request.param
236236

237237

238-
@pytest.fixture(
239-
params=[
240-
(pd.Index, False),
241-
(pd.Series, False),
242-
(pd.DataFrame, False),
243-
pytest.param((pd.DataFrame, True), marks=pytest.mark.xfail),
244-
(tm.to_array, False),
245-
],
246-
ids=id_func,
247-
)
248-
def box_transpose_fail(request):
249-
"""
250-
Fixture similar to `box` but testing both transpose cases for DataFrame,
251-
with the transpose=True case xfailed.
252-
"""
253-
# GH#23620
254-
return request.param
255-
256-
257238
@pytest.fixture(params=[pd.Index, pd.Series, pd.DataFrame, tm.to_array], ids=id_func)
258239
def box_with_array(request):
259240
"""

pandas/tests/arithmetic/test_period.py

+11-16
Original file line numberDiff line numberDiff line change
@@ -753,18 +753,18 @@ def test_pi_sub_isub_offset(self):
753753
rng -= pd.offsets.MonthEnd(5)
754754
tm.assert_index_equal(rng, expected)
755755

756-
def test_pi_add_offset_n_gt1(self, box_transpose_fail):
756+
@pytest.mark.parametrize("transpose", [True, False])
757+
def test_pi_add_offset_n_gt1(self, box_with_array, transpose):
757758
# GH#23215
758759
# add offset to PeriodIndex with freq.n > 1
759-
box, transpose = box_transpose_fail
760760

761761
per = pd.Period("2016-01", freq="2M")
762762
pi = pd.PeriodIndex([per])
763763

764764
expected = pd.PeriodIndex(["2016-03"], freq="2M")
765765

766-
pi = tm.box_expected(pi, box, transpose=transpose)
767-
expected = tm.box_expected(expected, box, transpose=transpose)
766+
pi = tm.box_expected(pi, box_with_array, transpose=transpose)
767+
expected = tm.box_expected(expected, box_with_array, transpose=transpose)
768768

769769
result = pi + per.freq
770770
tm.assert_equal(result, expected)
@@ -982,16 +982,15 @@ def test_pi_add_sub_timedeltalike_freq_mismatch_monthly(self, mismatched_freq):
982982
with pytest.raises(IncompatibleFrequency, match=msg):
983983
rng -= other
984984

985-
def test_parr_add_sub_td64_nat(self, box_transpose_fail):
985+
@pytest.mark.parametrize("transpose", [True, False])
986+
def test_parr_add_sub_td64_nat(self, box_with_array, transpose):
986987
# GH#23320 special handling for timedelta64("NaT")
987-
box, transpose = box_transpose_fail
988-
989988
pi = pd.period_range("1994-04-01", periods=9, freq="19D")
990989
other = np.timedelta64("NaT")
991990
expected = pd.PeriodIndex(["NaT"] * 9, freq="19D")
992991

993-
obj = tm.box_expected(pi, box, transpose=transpose)
994-
expected = tm.box_expected(expected, box, transpose=transpose)
992+
obj = tm.box_expected(pi, box_with_array, transpose=transpose)
993+
expected = tm.box_expected(expected, box_with_array, transpose=transpose)
995994

996995
result = obj + other
997996
tm.assert_equal(result, expected)
@@ -1009,16 +1008,12 @@ def test_parr_add_sub_td64_nat(self, box_transpose_fail):
10091008
TimedeltaArray._from_sequence(["NaT"] * 9),
10101009
],
10111010
)
1012-
def test_parr_add_sub_tdt64_nat_array(self, box_df_fail, other):
1013-
# FIXME: DataFrame fails because when when operating column-wise
1014-
# timedelta64 entries become NaT and are treated like datetimes
1015-
box = box_df_fail
1016-
1011+
def test_parr_add_sub_tdt64_nat_array(self, box_with_array, other):
10171012
pi = pd.period_range("1994-04-01", periods=9, freq="19D")
10181013
expected = pd.PeriodIndex(["NaT"] * 9, freq="19D")
10191014

1020-
obj = tm.box_expected(pi, box)
1021-
expected = tm.box_expected(expected, box)
1015+
obj = tm.box_expected(pi, box_with_array)
1016+
expected = tm.box_expected(expected, box_with_array)
10221017

10231018
result = obj + other
10241019
tm.assert_equal(result, expected)

pandas/tests/extension/base/reshaping.py

+16
Original file line numberDiff line numberDiff line change
@@ -295,3 +295,19 @@ def test_ravel(self, data):
295295
# Check that we have a view, not a copy
296296
result[0] = result[1]
297297
assert data[0] == data[1]
298+
299+
def test_transpose(self, data):
300+
df = pd.DataFrame({"A": data[:4], "B": data[:4]}, index=["a", "b", "c", "d"])
301+
result = df.T
302+
expected = pd.DataFrame(
303+
{
304+
"a": type(data)._from_sequence([data[0]] * 2, dtype=data.dtype),
305+
"b": type(data)._from_sequence([data[1]] * 2, dtype=data.dtype),
306+
"c": type(data)._from_sequence([data[2]] * 2, dtype=data.dtype),
307+
"d": type(data)._from_sequence([data[3]] * 2, dtype=data.dtype),
308+
},
309+
index=["A", "B"],
310+
)
311+
self.assert_frame_equal(result, expected)
312+
self.assert_frame_equal(np.transpose(np.transpose(df)), df)
313+
self.assert_frame_equal(np.transpose(np.transpose(df[["A"]])), df[["A"]])

pandas/tests/extension/json/test_json.py

+4
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,10 @@ def test_unstack(self, data, index):
163163
# this matches otherwise
164164
return super().test_unstack(data, index)
165165

166+
@pytest.mark.xfail(reason="Inconsistent sizes.")
167+
def test_transpose(self, data):
168+
super().test_transpose(data)
169+
166170

167171
class TestGetitem(BaseJSON, base.BaseGetitemTests):
168172
pass

pandas/tests/extension/test_numpy.py

+4
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,10 @@ def test_merge_on_extension_array_duplicates(self, data):
332332
# Fails creating expected
333333
super().test_merge_on_extension_array_duplicates(data)
334334

335+
@skip_nested
336+
def test_transpose(self, data):
337+
super().test_transpose(data)
338+
335339

336340
class TestSetitem(BaseNumPyTests, base.BaseSetitemTests):
337341
@skip_nested

0 commit comments

Comments
 (0)