Skip to content

Commit 1f9b699

Browse files
evanpw authored and jreback committed
BUG: where behaves badly when dtype of self is datetime or timedelta, and dtype of other is not (GH9804)
1 parent 17132ac commit 1f9b699

File tree

4 files changed

+59
-12
lines changed

4 files changed

+59
-12
lines changed

doc/source/whatsnew/v0.16.1.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,6 @@ Bug Fixes
117117
- Bug in which ``read_csv()`` interprets ``index_col=True`` as ``1`` (:issue:`9798`)
118118
- Bug in index equality comparisons using ``==`` failing on Index/MultiIndex type incompatibility (:issue:`9875`)
119119
- Bug in which ``SparseDataFrame`` could not take `nan` as a column name (:issue:`8822`)
120-
- Bug in ``Series.quantile`` on empty Series of type ``Datetime`` or ``Timedelta`` (:issue:`9675`)
121120
- Bug in ``to_msgpack`` and ``read_msgpack`` zlib and blosc compression support (:issue:`9783`)
122121
- Bug in unequal comparisons between a ``Series`` of dtype `"category"` and a scalar (e.g. ``Series(Categorical(list("abc"), categories=list("cba"), ordered=True)) > "b"``), which wouldn't use the order of the categories but used the lexicographical order. (:issue:`9848`)
123122

@@ -135,10 +134,11 @@ Bug Fixes
135134

136135

137136

137+
138138
- Bug in unequal comparisons between categorical data and a scalar, which was not in the categories (e.g. ``Series(Categorical(list("abc"), ordered=True)) > "d"``). This returned ``False`` for all elements, but now raises a ``TypeError``. Equality comparisons also now return ``False`` for ``==`` and ``True`` for ``!=``. (:issue:`9848`)
139139

140140
- Bug in DataFrame ``__setitem__`` when right hand side is a dictionary (:issue:`9874`)
141-
141+
- Bug in ``where`` when dtype is ``datetime64/timedelta64``, but dtype of other is not (:issue:`9804`)
142142
- Bug in ``MultiIndex.sortlevel()`` results in unicode level name breaks (:issue:`9875`)
143143
- Bug in which ``groupby.transform`` incorrectly enforced output dtypes to match input dtypes. (:issue:`9807`)
144144

pandas/core/generic.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -3335,7 +3335,8 @@ def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
33353335
except ValueError:
33363336
new_other = np.array(other)
33373337

3338-
if not (new_other == np.array(other)).all():
3338+
matches = (new_other == np.array(other))
3339+
if matches is False or not matches.all():
33393340
other = np.array(other)
33403341

33413342
# we can't use our existing dtype

pandas/core/internals.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -1325,13 +1325,11 @@ def _try_fill(self, value):
13251325
return value
13261326

13271327
def _try_coerce_args(self, values, other):
1328-
""" provide coercion to our input arguments
1329-
we are going to compare vs i8, so coerce to floats
1330-
repring NaT with np.nan so nans propagate
1331-
values is always ndarray like, other may not be """
1328+
""" Coerce values and other to float64, with null values converted to
1329+
NaN. values is always ndarray-like, other may not be """
13321330
def masker(v):
13331331
mask = isnull(v)
1334-
v = v.view('i8').astype('float64')
1332+
v = v.astype('float64')
13351333
v[mask] = np.nan
13361334
return v
13371335

@@ -1343,6 +1341,8 @@ def masker(v):
13431341
other = _coerce_scalar_to_timedelta_type(other, unit='s', box=False).item()
13441342
if other == tslib.iNaT:
13451343
other = np.nan
1344+
elif lib.isscalar(other):
1345+
other = np.float64(other)
13461346
else:
13471347
other = masker(other)
13481348

@@ -1809,16 +1809,20 @@ def _try_operate(self, values):
18091809
return values.view('i8')
18101810

18111811
def _try_coerce_args(self, values, other):
1812-
""" provide coercion to our input arguments
1813-
we are going to compare vs i8, so coerce to integer
1814-
values is always ndarra like, other may not be """
1812+
""" Coerce values and other to dtype 'i8'. NaN and NaT convert to
1813+
the smallest i8, and will correctly round-trip to NaT if converted
1814+
back in _try_coerce_result. values is always ndarray-like, other
1815+
may not be """
18151816
values = values.view('i8')
1817+
18161818
if is_null_datelike_scalar(other):
18171819
other = tslib.iNaT
18181820
elif isinstance(other, datetime):
18191821
other = lib.Timestamp(other).asm8.view('i8')
1820-
else:
1822+
elif hasattr(other, 'dtype') and com.is_integer_dtype(other):
18211823
other = other.view('i8')
1824+
else:
1825+
other = np.array(other, dtype='i8')
18221826

18231827
return values, other
18241828

pandas/tests/test_series.py

+42
Original file line numberDiff line numberDiff line change
@@ -1859,6 +1859,48 @@ def test_where_dups(self):
18591859
expected = Series([5,11,2,5,11,2],index=[0,1,2,0,1,2])
18601860
assert_series_equal(comb, expected)
18611861

1862+
def test_where_datetime(self):
1863+
s = Series(date_range('20130102', periods=2))
1864+
expected = Series([10, 10], dtype='datetime64[ns]')
1865+
mask = np.array([False, False])
1866+
1867+
rs = s.where(mask, [10, 10])
1868+
assert_series_equal(rs, expected)
1869+
1870+
rs = s.where(mask, 10)
1871+
assert_series_equal(rs, expected)
1872+
1873+
rs = s.where(mask, 10.0)
1874+
assert_series_equal(rs, expected)
1875+
1876+
rs = s.where(mask, [10.0, 10.0])
1877+
assert_series_equal(rs, expected)
1878+
1879+
rs = s.where(mask, [10.0, np.nan])
1880+
expected = Series([10, None], dtype='datetime64[ns]')
1881+
assert_series_equal(rs, expected)
1882+
1883+
def test_where_timedelta(self):
1884+
s = Series([1, 2], dtype='timedelta64[ns]')
1885+
expected = Series([10, 10], dtype='timedelta64[ns]')
1886+
mask = np.array([False, False])
1887+
1888+
rs = s.where(mask, [10, 10])
1889+
assert_series_equal(rs, expected)
1890+
1891+
rs = s.where(mask, 10)
1892+
assert_series_equal(rs, expected)
1893+
1894+
rs = s.where(mask, 10.0)
1895+
assert_series_equal(rs, expected)
1896+
1897+
rs = s.where(mask, [10.0, 10.0])
1898+
assert_series_equal(rs, expected)
1899+
1900+
rs = s.where(mask, [10.0, np.nan])
1901+
expected = Series([10, None], dtype='timedelta64[ns]')
1902+
assert_series_equal(rs, expected)
1903+
18621904
def test_mask(self):
18631905
# compare with tested results in test_where
18641906
s = Series(np.random.randn(5))

0 commit comments

Comments
 (0)