Skip to content

BUG: retain extension dtypes in transpose #28048

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
11 changes: 10 additions & 1 deletion pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,16 @@ def transpose(self, *args, **kwargs):
new_values = new_values.copy()

nv.validate_transpose(tuple(), kwargs)
return self._constructor(new_values, **new_axes).__finalize__(self)
result = self._constructor(new_values, **new_axes).__finalize__(self)

if self.ndim == 2 and self._is_homogeneous_type and len(self.columns):
Copy link
Contributor

@jreback jreback Dec 1, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think its worthwhile to make a method to encapsulate this maybe

def _homogeneous_dtype(self):
   # return the single dtype if homogeneous, None if not

if is_extension_array_dtype(self.dtypes.iloc[0]):
# Retain ExtensionArray dtypes through transpose;
# TODO: this can be made cleaner if/when (N, 1) EA are allowed
dtype = self.dtypes.iloc[0]
result = result.astype(dtype)

return result

def swapaxes(self, axis1, axis2, copy=True):
"""
Expand Down
12 changes: 4 additions & 8 deletions pandas/tests/arithmetic/test_datetime64.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,10 @@ def test_dt64arr_nat_comparison(self, tz_naive_fixture, box_with_array):
ts = pd.Timestamp.now(tz)
ser = pd.Series([ts, pd.NaT])

# FIXME: Can't transpose because that loses the tz dtype on
# the NaT column
obj = tm.box_expected(ser, box, transpose=False)
obj = tm.box_expected(ser, box)

expected = pd.Series([True, False], dtype=np.bool_)
expected = tm.box_expected(expected, xbox, transpose=False)
expected = tm.box_expected(expected, xbox)

result = obj == ts
tm.assert_equal(result, expected)
Expand Down Expand Up @@ -842,10 +840,8 @@ def test_dt64arr_add_sub_td64_nat(self, box_with_array, tz_naive_fixture):
other = np.timedelta64("NaT")
expected = pd.DatetimeIndex(["NaT"] * 9, tz=tz)

# FIXME: fails with transpose=True due to tz-aware DataFrame
# transpose bug
obj = tm.box_expected(dti, box_with_array, transpose=False)
expected = tm.box_expected(expected, box_with_array, transpose=False)
obj = tm.box_expected(dti, box_with_array)
expected = tm.box_expected(expected, box_with_array)

result = obj + other
tm.assert_equal(result, expected)
Expand Down
26 changes: 10 additions & 16 deletions pandas/tests/arithmetic/test_period.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,18 +755,16 @@ def test_pi_sub_isub_offset(self):
rng -= pd.offsets.MonthEnd(5)
tm.assert_index_equal(rng, expected)

def test_pi_add_offset_n_gt1(self, box_transpose_fail):
def test_pi_add_offset_n_gt1(self, box):
# GH#23215
# add offset to PeriodIndex with freq.n > 1
box, transpose = box_transpose_fail

per = pd.Period("2016-01", freq="2M")
pi = pd.PeriodIndex([per])

expected = pd.PeriodIndex(["2016-03"], freq="2M")

pi = tm.box_expected(pi, box, transpose=transpose)
expected = tm.box_expected(expected, box, transpose=transpose)
pi = tm.box_expected(pi, box)
expected = tm.box_expected(expected, box)

result = pi + per.freq
tm.assert_equal(result, expected)
Expand All @@ -780,9 +778,8 @@ def test_pi_add_offset_n_gt1_not_divisible(self, box_with_array):
pi = pd.PeriodIndex(["2016-01"], freq="2M")
expected = pd.PeriodIndex(["2016-04"], freq="2M")

# FIXME: with transposing these tests fail
pi = tm.box_expected(pi, box_with_array, transpose=False)
expected = tm.box_expected(expected, box_with_array, transpose=False)
pi = tm.box_expected(pi, box_with_array)
expected = tm.box_expected(expected, box_with_array)

result = pi + to_offset("3M")
tm.assert_equal(result, expected)
Expand Down Expand Up @@ -984,16 +981,15 @@ def test_pi_add_sub_timedeltalike_freq_mismatch_monthly(self, mismatched_freq):
with pytest.raises(IncompatibleFrequency, match=msg):
rng -= other

def test_parr_add_sub_td64_nat(self, box_transpose_fail):
def test_parr_add_sub_td64_nat(self, box):
# GH#23320 special handling for timedelta64("NaT")
box, transpose = box_transpose_fail

pi = pd.period_range("1994-04-01", periods=9, freq="19D")
other = np.timedelta64("NaT")
expected = pd.PeriodIndex(["NaT"] * 9, freq="19D")

obj = tm.box_expected(pi, box, transpose=transpose)
expected = tm.box_expected(expected, box, transpose=transpose)
obj = tm.box_expected(pi, box)
expected = tm.box_expected(expected, box)

result = obj + other
tm.assert_equal(result, expected)
Expand All @@ -1011,10 +1007,8 @@ def test_parr_add_sub_td64_nat(self, box_transpose_fail):
TimedeltaArray._from_sequence(["NaT"] * 9),
],
)
def test_parr_add_sub_tdt64_nat_array(self, box_df_fail, other):
# FIXME: DataFrame fails because when when operating column-wise
# timedelta64 entries become NaT and are treated like datetimes
box = box_df_fail
def test_parr_add_sub_tdt64_nat_array(self, box_with_array, other):
box = box_with_array

pi = pd.period_range("1994-04-01", periods=9, freq="19D")
expected = pd.PeriodIndex(["NaT"] * 9, freq="19D")
Expand Down
16 changes: 16 additions & 0 deletions pandas/tests/frame/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,22 @@ def test_no_warning(self, all_arithmetic_operators):


class TestTranspose:
@pytest.mark.parametrize(
"ser",
[
pd.date_range("2016-04-05 04:30", periods=3, tz="UTC"),
pd.period_range("1994", freq="A", periods=3),
pd.period_range("1969", freq="9s", periods=1),
pd.date_range("2016-04-05 04:30", periods=3).astype("category"),
pd.date_range("2016-04-05 04:30", periods=3, tz="UTC").astype("category"),
],
)
def test_transpose_retains_extension_dtype(self, ser):
# case with more than 1 column, must have same dtype
df = pd.DataFrame({"a": ser, "b": ser})
result = df.T
assert (result.dtypes == ser.dtype).all()

def test_transpose_tzaware_1col_single_tz(self):
# GH#26825
dti = pd.date_range("2016-04-05 04:30", periods=3, tz="UTC")
Expand Down