Skip to content

Commit 7c35cdf

Browse files
jbrockmendel authored and yehoshuadimarsky committed
BUG: DataFrame[dt64].where downcasting (pandas-dev#45837)
* BUG: DataFrame[dt64].where downcasting
* whatsnew
* GH ref
* punt on dt64
* mypy fixup
1 parent 61002e5 commit 7c35cdf

File tree

7 files changed

+112
-16
lines changed

7 files changed

+112
-16
lines changed

doc/source/whatsnew/v1.5.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@ Indexing
318318
- Bug in setting a NA value (``None`` or ``np.nan``) into a :class:`Series` with int-based :class:`IntervalDtype` incorrectly casting to object dtype instead of a float-based :class:`IntervalDtype` (:issue:`45568`)
319319
- Bug in :meth:`Series.__setitem__` with a non-integer :class:`Index` when using an integer key to set a value that cannot be set inplace where a ``ValueError`` was raised instead of casting to a common dtype (:issue:`45070`)
320320
- Bug in :meth:`Series.__setitem__` when setting incompatible values into a ``PeriodDtype`` or ``IntervalDtype`` :class:`Series` raising when indexing with a boolean mask but coercing when indexing with otherwise-equivalent indexers; these now consistently coerce, along with :meth:`Series.mask` and :meth:`Series.where` (:issue:`45768`)
321+
- Bug in :meth:`DataFrame.where` with multiple columns with datetime-like dtypes failing to downcast results consistently with other dtypes (:issue:`45837`)
321322
- Bug in :meth:`Series.loc.__setitem__` and :meth:`Series.loc.__getitem__` not raising when using multiple keys without using a :class:`MultiIndex` (:issue:`13831`)
322323
- Bug when setting a value too large for a :class:`Series` dtype failing to coerce to a common type (:issue:`26049`, :issue:`32878`)
323324
- Bug in :meth:`loc.__setitem__` treating ``range`` keys as positional instead of label-based (:issue:`45479`)

pandas/core/arrays/boolean.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def coerce_to_array(
214214
raise TypeError("Need to pass bool-like values")
215215

216216
if mask is None and mask_values is None:
217-
mask = np.zeros(len(values), dtype=bool)
217+
mask = np.zeros(values.shape, dtype=bool)
218218
elif mask is None:
219219
mask = mask_values
220220
else:

pandas/core/dtypes/cast.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def maybe_downcast_to_dtype(result: ArrayLike, dtype: str | np.dtype) -> ArrayLi
262262
dtype = "int64"
263263
elif inferred_type == "datetime64":
264264
dtype = "datetime64[ns]"
265-
elif inferred_type == "timedelta64":
265+
elif inferred_type in ["timedelta", "timedelta64"]:
266266
dtype = "timedelta64[ns]"
267267

268268
# try to upcast here
@@ -290,6 +290,14 @@ def maybe_downcast_to_dtype(result: ArrayLike, dtype: str | np.dtype) -> ArrayLi
290290
if dtype.kind in ["M", "m"] and result.dtype.kind in ["i", "f"]:
291291
result = result.astype(dtype)
292292

293+
elif dtype.kind == "m" and result.dtype == _dtype_obj:
294+
# test_where_downcast_to_td64
295+
result = cast(np.ndarray, result)
296+
result = array_to_timedelta64(result)
297+
298+
elif dtype == "M8[ns]" and result.dtype == _dtype_obj:
299+
return np.asarray(maybe_cast_to_datetime(result, dtype=dtype))
300+
293301
return result
294302

295303

pandas/core/internals/blocks.py

+31-12
Original file line numberDiff line numberDiff line change
@@ -1426,21 +1426,40 @@ def where(self, other, cond, _downcast="infer") -> list[Block]:
14261426
except (ValueError, TypeError) as err:
14271427
_catch_deprecated_value_error(err)
14281428

1429-
if is_interval_dtype(self.dtype):
1430-
# TestSetitemFloatIntervalWithIntIntervalValues
1431-
blk = self.coerce_to_target_dtype(orig_other)
1432-
nbs = blk.where(orig_other, orig_cond)
1433-
return self._maybe_downcast(nbs, downcast=_downcast)
1429+
if self.ndim == 1 or self.shape[0] == 1:
14341430

1435-
elif isinstance(self, NDArrayBackedExtensionBlock):
1436-
# NB: not (yet) the same as
1437-
# isinstance(values, NDArrayBackedExtensionArray)
1438-
blk = self.coerce_to_target_dtype(orig_other)
1439-
nbs = blk.where(orig_other, orig_cond)
1440-
return self._maybe_downcast(nbs, "infer")
1431+
if is_interval_dtype(self.dtype):
1432+
# TestSetitemFloatIntervalWithIntIntervalValues
1433+
blk = self.coerce_to_target_dtype(orig_other)
1434+
nbs = blk.where(orig_other, orig_cond)
1435+
return self._maybe_downcast(nbs, downcast=_downcast)
1436+
1437+
elif isinstance(self, NDArrayBackedExtensionBlock):
1438+
# NB: not (yet) the same as
1439+
# isinstance(values, NDArrayBackedExtensionArray)
1440+
blk = self.coerce_to_target_dtype(orig_other)
1441+
nbs = blk.where(orig_other, orig_cond)
1442+
return self._maybe_downcast(nbs, downcast=_downcast)
1443+
1444+
else:
1445+
raise
14411446

14421447
else:
1443-
raise
1448+
# Same pattern we use in Block.putmask
1449+
is_array = isinstance(orig_other, (np.ndarray, ExtensionArray))
1450+
1451+
res_blocks = []
1452+
nbs = self._split()
1453+
for i, nb in enumerate(nbs):
1454+
n = orig_other
1455+
if is_array:
1456+
# we have a different value per-column
1457+
n = orig_other[:, i : i + 1]
1458+
1459+
submask = orig_cond[:, i : i + 1]
1460+
rbs = nb.where(n, submask)
1461+
res_blocks.extend(rbs)
1462+
return res_blocks
14441463

14451464
nb = self.make_block_same_class(res_values)
14461465
return [nb]

pandas/tests/arrays/boolean/test_construction.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,11 @@ def test_coerce_to_array():
183183
values = np.array([True, False, True, False], dtype="bool")
184184
mask = np.array([False, False, False, True], dtype="bool")
185185

186+
# passing 2D values is OK as long as no mask
187+
coerce_to_array(values.reshape(1, -1))
188+
186189
with pytest.raises(ValueError, match="values.shape and mask.shape must match"):
187-
coerce_to_array(values.reshape(1, -1))
190+
coerce_to_array(values.reshape(1, -1), mask=mask)
188191

189192
with pytest.raises(ValueError, match="values.shape and mask.shape must match"):
190193
coerce_to_array(values, mask=mask.reshape(1, -1))

pandas/tests/dtypes/cast/test_downcast.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55

66
from pandas.core.dtypes.cast import maybe_downcast_to_dtype
77

8-
from pandas import Series
8+
from pandas import (
9+
Series,
10+
Timedelta,
11+
)
912
import pandas._testing as tm
1013

1114

@@ -34,6 +37,13 @@
3437
"int64",
3538
np.array([decimal.Decimal(0.0)]),
3639
),
40+
(
41+
# GH#45837
42+
np.array([Timedelta(days=1), Timedelta(days=2)], dtype=object),
43+
"infer",
44+
np.array([1, 2], dtype="m8[D]").astype("m8[ns]"),
45+
),
46+
# TODO: similar for dt64, dt64tz, Period, Interval?
3747
],
3848
)
3949
def test_downcast(arr, expected, dtype):

pandas/tests/frame/indexing/test_where.py

+55
Original file line numberDiff line numberDiff line change
@@ -965,3 +965,58 @@ def test_where_inplace_casting(data):
965965
df_copy = df.where(pd.notnull(df), None).copy()
966966
df.where(pd.notnull(df), None, inplace=True)
967967
tm.assert_equal(df, df_copy)
968+
969+
970+
def test_where_downcast_to_td64():
971+
ser = Series([1, 2, 3])
972+
973+
mask = np.array([False, False, False])
974+
975+
td = pd.Timedelta(days=1)
976+
977+
res = ser.where(mask, td)
978+
expected = Series([td, td, td], dtype="m8[ns]")
979+
tm.assert_series_equal(res, expected)
980+
981+
982+
def _check_where_equivalences(df, mask, other, expected):
983+
# similar to tests.series.indexing.test_setitem.SetitemCastingEquivalences
984+
# but with DataFrame in mind and less fleshed-out
985+
res = df.where(mask, other)
986+
tm.assert_frame_equal(res, expected)
987+
988+
res = df.mask(~mask, other)
989+
tm.assert_frame_equal(res, expected)
990+
991+
# Note: we cannot do the same with frame.mask(~mask, other, inplace=True)
992+
# bc that goes through Block.putmask which does *not* downcast.
993+
994+
995+
def test_where_dt64_2d():
996+
dti = date_range("2016-01-01", periods=6)
997+
dta = dti._data.reshape(3, 2)
998+
other = dta - dta[0, 0]
999+
1000+
df = DataFrame(dta, columns=["A", "B"])
1001+
1002+
mask = np.asarray(df.isna())
1003+
mask[:, 1] = True
1004+
1005+
# setting all of one column, none of the other
1006+
expected = DataFrame({"A": other[:, 0], "B": dta[:, 1]})
1007+
_check_where_equivalences(df, mask, other, expected)
1008+
1009+
# setting part of one column, none of the other
1010+
mask[1, 0] = True
1011+
expected = DataFrame(
1012+
{
1013+
"A": np.array([other[0, 0], dta[1, 0], other[2, 0]], dtype=object),
1014+
"B": dta[:, 1],
1015+
}
1016+
)
1017+
_check_where_equivalences(df, mask, other, expected)
1018+
1019+
# setting nothing in either column
1020+
mask[:] = True
1021+
expected = df
1022+
_check_where_equivalences(df, mask, other, expected)

0 commit comments

Comments
 (0)