diff --git a/doc/source/whatsnew/v0.16.1.txt b/doc/source/whatsnew/v0.16.1.txt index 8c49e2780ed06..3949ad4394fb1 100644 --- a/doc/source/whatsnew/v0.16.1.txt +++ b/doc/source/whatsnew/v0.16.1.txt @@ -79,3 +79,4 @@ Bug Fixes - Bug in ``Series.quantile`` on empty Series of type ``Datetime`` or ``Timedelta`` (:issue:`9675`) - Bug in ``where`` causing incorrect results when upcasting was required (:issue:`9731`) +- Bug in ``where`` when dtype of self is datetime64 or timedelta64, but dtype of other is not \ No newline at end of file diff --git a/pandas/core/generic.py b/pandas/core/generic.py index e05709d7a180f..e1a50f8b5200a 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -3323,7 +3323,8 @@ def where(self, cond, other=np.nan, inplace=False, axis=None, level=None, except ValueError: new_other = np.array(other) - if not (new_other == np.array(other)).all(): + matches = (new_other == np.array(other)) + if matches is False or not matches.all(): other = np.array(other) # we can't use our existing dtype diff --git a/pandas/core/internals.py b/pandas/core/internals.py index 7a16fb2b6b0d7..3418295adabf5 100644 --- a/pandas/core/internals.py +++ b/pandas/core/internals.py @@ -1324,13 +1324,11 @@ def _try_fill(self, value): return value def _try_coerce_args(self, values, other): - """ provide coercion to our input arguments - we are going to compare vs i8, so coerce to floats - repring NaT with np.nan so nans propagate - values is always ndarray like, other may not be """ + """ Coerce values and other to float64, with null values converted to + NaN. values is always ndarray-like, other may not be """ def masker(v): mask = isnull(v) - v = v.view('i8').astype('float64') + v = v.astype('float64') v[mask] = np.nan return v @@ -1342,6 +1340,8 @@ def masker(v): other = _coerce_scalar_to_timedelta_type(other, unit='s', box=False).item() if other == tslib.iNaT: other = np.nan + elif lib.isscalar(other): + other = np.float64(other) else: other = masker(other) @@ -1807,16 +1807,20 @@ def _try_operate(self, values): return values.view('i8') def _try_coerce_args(self, values, other): - """ provide coercion to our input arguments - we are going to compare vs i8, so coerce to integer - values is always ndarra like, other may not be """ + """ Coerce values and other to dtype 'i8'. NaN and NaT convert to + the smallest i8, and will correctly round-trip to NaT if converted + back in _try_coerce_result. values is always ndarray-like, other + may not be """ values = values.view('i8') + if is_null_datelike_scalar(other): other = tslib.iNaT elif isinstance(other, datetime): other = lib.Timestamp(other).asm8.view('i8') - else: + elif hasattr(other, 'dtype') and com.is_integer_dtype(other): other = other.view('i8') + else: + other = np.array(other, dtype='i8') return values, other diff --git a/pandas/tests/test_series.py b/pandas/tests/test_series.py index e140ffd97051c..25487a081571a 100644 --- a/pandas/tests/test_series.py +++ b/pandas/tests/test_series.py @@ -1855,6 +1855,48 @@ def test_where_dups(self): expected = Series([5,11,2,5,11,2],index=[0,1,2,0,1,2]) assert_series_equal(comb, expected) + def test_where_datetime(self): + s = Series(date_range('20130102', periods=2)) + expected = Series([10, 10], dtype='datetime64[ns]') + mask = np.array([False, False]) + + rs = s.where(mask, [10, 10]) + assert_series_equal(rs, expected) + + rs = s.where(mask, 10) + assert_series_equal(rs, expected) + + rs = s.where(mask, 10.0) + assert_series_equal(rs, expected) + + rs = s.where(mask, [10.0, 10.0]) + assert_series_equal(rs, expected) + + rs = s.where(mask, [10.0, np.nan]) + expected = Series([10, None], dtype='datetime64[ns]') + assert_series_equal(rs, expected) + + def test_where_timedelta(self): + s = Series([1, 2], dtype='timedelta64[ns]') + expected = Series([10, 10], dtype='timedelta64[ns]') + mask = np.array([False, False]) + + rs = s.where(mask, [10, 10]) + assert_series_equal(rs, expected) + + rs = s.where(mask, 10) + assert_series_equal(rs, expected) + + rs = s.where(mask, 10.0) + assert_series_equal(rs, expected) + + rs = s.where(mask, [10.0, 10.0]) + assert_series_equal(rs, expected) + + rs = s.where(mask, [10.0, np.nan]) + expected = Series([10, None], dtype='timedelta64[ns]') + assert_series_equal(rs, expected) + def test_mask(self): s = Series(np.random.randn(5)) cond = s > 0