Skip to content

Commit 8e2bbee

Browse files
committed
BUG: retain extension dtypes in transpose
1 parent 7deda21 commit 8e2bbee

File tree

4 files changed

+40
-21
lines changed

4 files changed

+40
-21
lines changed

pandas/core/generic.py

+12 -1
Original file line number | Diff line number | Diff line change
@@ -725,7 +725,18 @@ def transpose(self, *args, **kwargs):
725725
new_values = new_values.copy()
726726

727727
nv.validate_transpose(tuple(), kwargs)
728-
return self._constructor(new_values, **new_axes).__finalize__(self)
728+
result = self._constructor(new_values, **new_axes).__finalize__(self)
729+
730+
if len(self.columns) and (self.dtypes == self.dtypes.iloc[0]).all():
731+
# FIXME: self.dtypes[0] can fail in tests
732+
if is_extension_array_dtype(self.dtypes.iloc[0]):
733+
# Retain ExtensionArray dtypes through transpose;
734+
# TODO: this can be made cleaner if/when (N, 1) EA are allowed
735+
dtype = self.dtypes[0]
736+
for col in result.columns:
737+
result[col] = result[col].astype(dtype)
738+
739+
return result
729740

730741
def swapaxes(self, axis1, axis2, copy=True):
731742
"""

pandas/tests/arithmetic/test_datetime64.py

+4 -8
Original file line number | Diff line number | Diff line change
@@ -153,12 +153,10 @@ def test_dt64arr_nat_comparison(self, tz_naive_fixture, box_with_array):
153153
ts = pd.Timestamp.now(tz)
154154
ser = pd.Series([ts, pd.NaT])
155155

156-
# FIXME: Can't transpose because that loses the tz dtype on
157-
# the NaT column
158-
obj = tm.box_expected(ser, box, transpose=False)
156+
obj = tm.box_expected(ser, box)
159157

160158
expected = pd.Series([True, False], dtype=np.bool_)
161-
expected = tm.box_expected(expected, xbox, transpose=False)
159+
expected = tm.box_expected(expected, xbox)
162160

163161
result = obj == ts
164162
tm.assert_equal(result, expected)
@@ -879,10 +877,8 @@ def test_dt64arr_add_sub_td64_nat(self, box_with_array, tz_naive_fixture):
879877
other = np.timedelta64("NaT")
880878
expected = pd.DatetimeIndex(["NaT"] * 9, tz=tz)
881879

882-
# FIXME: fails with transpose=True due to tz-aware DataFrame
883-
# transpose bug
884-
obj = tm.box_expected(dti, box_with_array, transpose=False)
885-
expected = tm.box_expected(expected, box_with_array, transpose=False)
880+
obj = tm.box_expected(dti, box_with_array)
881+
expected = tm.box_expected(expected, box_with_array)
886882

887883
result = obj + other
888884
tm.assert_equal(result, expected)

pandas/tests/arithmetic/test_period.py

+8 -12
Original file line number | Diff line number | Diff line change
@@ -755,18 +755,16 @@ def test_pi_sub_isub_offset(self):
755755
rng -= pd.offsets.MonthEnd(5)
756756
tm.assert_index_equal(rng, expected)
757757

758-
def test_pi_add_offset_n_gt1(self, box_transpose_fail):
758+
def test_pi_add_offset_n_gt1(self, box):
759759
# GH#23215
760760
# add offset to PeriodIndex with freq.n > 1
761-
box, transpose = box_transpose_fail
762-
763761
per = pd.Period("2016-01", freq="2M")
764762
pi = pd.PeriodIndex([per])
765763

766764
expected = pd.PeriodIndex(["2016-03"], freq="2M")
767765

768-
pi = tm.box_expected(pi, box, transpose=transpose)
769-
expected = tm.box_expected(expected, box, transpose=transpose)
766+
pi = tm.box_expected(pi, box)
767+
expected = tm.box_expected(expected, box)
770768

771769
result = pi + per.freq
772770
tm.assert_equal(result, expected)
@@ -780,9 +778,8 @@ def test_pi_add_offset_n_gt1_not_divisible(self, box_with_array):
780778
pi = pd.PeriodIndex(["2016-01"], freq="2M")
781779
expected = pd.PeriodIndex(["2016-04"], freq="2M")
782780

783-
# FIXME: with transposing these tests fail
784-
pi = tm.box_expected(pi, box_with_array, transpose=False)
785-
expected = tm.box_expected(expected, box_with_array, transpose=False)
781+
pi = tm.box_expected(pi, box_with_array)
782+
expected = tm.box_expected(expected, box_with_array)
786783

787784
result = pi + to_offset("3M")
788785
tm.assert_equal(result, expected)
@@ -984,16 +981,15 @@ def test_pi_add_sub_timedeltalike_freq_mismatch_monthly(self, mismatched_freq):
984981
with pytest.raises(IncompatibleFrequency, match=msg):
985982
rng -= other
986983

987-
def test_parr_add_sub_td64_nat(self, box_transpose_fail):
984+
def test_parr_add_sub_td64_nat(self, box):
988985
# GH#23320 special handling for timedelta64("NaT")
989-
box, transpose = box_transpose_fail
990986

991987
pi = pd.period_range("1994-04-01", periods=9, freq="19D")
992988
other = np.timedelta64("NaT")
993989
expected = pd.PeriodIndex(["NaT"] * 9, freq="19D")
994990

995-
obj = tm.box_expected(pi, box, transpose=transpose)
996-
expected = tm.box_expected(expected, box, transpose=transpose)
991+
obj = tm.box_expected(pi, box)
992+
expected = tm.box_expected(expected, box)
997993

998994
result = obj + other
999995
tm.assert_equal(result, expected)

pandas/tests/frame/test_operators.py

+16
Original file line number | Diff line number | Diff line change
@@ -844,6 +844,22 @@ def test_no_warning(self, all_arithmetic_operators):
844844

845845

846846
class TestTranspose:
847+
@pytest.mark.parametrize(
848+
"ser",
849+
[
850+
pd.date_range("2016-04-05 04:30", periods=3, tz="UTC"),
851+
pd.period_range("1994", freq="A", periods=3),
852+
pd.period_range("1969", freq="9s", periods=1),
853+
pd.date_range("2016-04-05 04:30", periods=3).astype("category"),
854+
pd.date_range("2016-04-05 04:30", periods=3, tz="UTC").astype("category"),
855+
],
856+
)
857+
def test_transpose_retains_extension_dtype(self, ser):
858+
# case with more than 1 column, must have same dtype
859+
df = pd.DataFrame({"a": ser, "b": ser})
860+
result = df.T
861+
assert (result.dtypes == ser.dtype).all()
862+
847863
def test_transpose_tzaware_1col_single_tz(self):
848864
# GH#26825
849865
dti = pd.date_range("2016-04-05 04:30", periods=3, tz="UTC")

0 commit comments

Comments
 (0)