From 954d74c21eea942c60456c23fa5d716445b99a4c Mon Sep 17 00:00:00 2001 From: Brock Date: Sun, 9 Jan 2022 21:59:45 -0800 Subject: [PATCH] BUG: DataFrame[dt64].quantile(axis=1) when empty returning f8 --- doc/source/whatsnew/v1.5.0.rst | 2 +- pandas/core/frame.py | 25 +++++++++++++---- pandas/tests/frame/methods/test_quantile.py | 31 +++++++++++++++------ 3 files changed, 44 insertions(+), 14 deletions(-) diff --git a/doc/source/whatsnew/v1.5.0.rst b/doc/source/whatsnew/v1.5.0.rst index e723918ad8b4b..1590846218cc7 100644 --- a/doc/source/whatsnew/v1.5.0.rst +++ b/doc/source/whatsnew/v1.5.0.rst @@ -118,7 +118,7 @@ Categorical Datetimelike ^^^^^^^^^^^^ -- +- Bug in :meth:`DataFrame.quantile` with datetime-like dtypes and no rows incorrectly returning ``float64`` dtype instead of retaining datetime-like dtype (:issue:`41544`) - Timedelta diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 8f9294f2c0437..ad6cc6fd93790 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -123,6 +123,7 @@ is_object_dtype, is_scalar, is_sequence, + needs_i8_conversion, pandas_dtype, ) from pandas.core.dtypes.dtypes import ExtensionDtype @@ -10460,17 +10461,23 @@ def quantile( Name: 0.5, dtype: object """ validate_percentile(q) + axis = self._get_axis_number(axis) if not is_list_like(q): # BlockManager.quantile expects listlike, so we wrap and unwrap here - res = self.quantile( + res_df = self.quantile( [q], axis=axis, numeric_only=numeric_only, interpolation=interpolation ) - return res.iloc[0] + res = res_df.iloc[0] + if axis == 1 and len(self) == 0: + # GH#41544 try to get an appropriate dtype + dtype = find_common_type(list(self.dtypes)) + if needs_i8_conversion(dtype): + return res.astype(dtype) + return res q = Index(q, dtype=np.float64) data = self._get_numeric_data() if numeric_only else self - axis = self._get_axis_number(axis) if axis == 1: data = data.T @@ -10478,9 +10485,17 @@ def quantile( if len(data.columns) == 0: # GH#23925 _get_numeric_data may have dropped all columns cols = Index([], name=self.columns.name) + + dtype = np.float64 + if axis == 1: + # GH#41544 try to get an appropriate dtype + cdtype = find_common_type(list(self.dtypes)) + if needs_i8_conversion(cdtype): + dtype = cdtype + if is_list_like(q): - return self._constructor([], index=q, columns=cols) - return self._constructor_sliced([], index=cols, name=q, dtype=np.float64) + return self._constructor([], index=q, columns=cols, dtype=dtype) + return self._constructor_sliced([], index=cols, name=q, dtype=dtype) res = data._mgr.quantile(qs=q, axis=1, interpolation=interpolation) diff --git a/pandas/tests/frame/methods/test_quantile.py b/pandas/tests/frame/methods/test_quantile.py index 8ff1b211c0db1..d60fe7d4f316d 100644 --- a/pandas/tests/frame/methods/test_quantile.py +++ b/pandas/tests/frame/methods/test_quantile.py @@ -293,6 +293,28 @@ def test_quantile_datetime(self): expected = DataFrame(index=[0.5]) tm.assert_frame_equal(result, expected) + @pytest.mark.parametrize( + "dtype", + [ + "datetime64[ns]", + "datetime64[ns, US/Pacific]", + "timedelta64[ns]", + "Period[D]", + ], + ) + def test_quantile_dt64_empty(self, dtype): + # GH#41544 + df = DataFrame(columns=["a", "b"], dtype=dtype) + + res = df.quantile(0.5, axis=1, numeric_only=False) + expected = Series([], index=[], name=0.5, dtype=dtype) + tm.assert_series_equal(res, expected) + + # no columns in result, so no dtype preservation + res = df.quantile([0.5], axis=1, numeric_only=False) + expected = DataFrame(index=[0.5]) + tm.assert_frame_equal(res, expected) + def test_quantile_invalid(self, datetime_frame): msg = "percentiles should all be in the interval \\[0, 1\\]" for invalid in [-1, 2, [0.5, -1], [0.5, 2]]: @@ -722,14 +744,7 @@ def test_empty_numeric(self, dtype, expected_data, expected_index, axis): @pytest.mark.parametrize( "dtype, expected_data, expected_index, axis, expected_dtype", [ - pytest.param( - "datetime64[ns]", - [], - [], - 1, - "datetime64[ns]", - marks=pytest.mark.xfail(reason="#GH 41544"), - ), + ["datetime64[ns]", [], [], 1, "datetime64[ns]"], ["datetime64[ns]", [pd.NaT, pd.NaT], ["a", "b"], 0, "datetime64[ns]"], ], )