From 1f93c055781d14efa4ab74421ed0a788ad0c6d76 Mon Sep 17 00:00:00 2001 From: Brock Date: Sun, 1 Nov 2020 18:57:58 -0800 Subject: [PATCH 1/4] CLN: use validate_fill_value in Index.where --- pandas/core/indexes/base.py | 18 +++++------------- pandas/core/indexes/datetimelike.py | 11 ++--------- pandas/tests/indexes/datetimelike.py | 3 +-- .../tests/indexes/datetimes/test_indexing.py | 16 +++++++++------- pandas/tests/indexes/period/test_indexing.py | 11 ++++++----- .../tests/indexes/timedeltas/test_indexing.py | 11 ++++++----- pandas/tests/indexing/test_coercion.py | 7 ++++--- 7 files changed, 33 insertions(+), 44 deletions(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 1938722225b98..ef1dd69402b37 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -40,7 +40,6 @@ ensure_int64, ensure_object, ensure_platform_int, - is_bool, is_bool_dtype, is_categorical_dtype, is_datetime64_any_dtype, @@ -4066,23 +4065,16 @@ def where(self, cond, other=None): if other is None: other = self._na_value - dtype = self.dtype values = self.values - if is_bool(other) or is_bool_dtype(other): - - # bools force casting - values = values.astype(object) - dtype = None + try: + self._validate_fill_value(other) + except (ValueError, TypeError): + return self.astype(object).where(cond, other) values = np.where(cond, values, other) - if self._is_numeric_dtype and np.any(isna(values)): - # We can't coerce to the numeric dtype of "self" (unless - # it's float) if there are NaN values in our output. - dtype = None - - return Index(values, dtype=dtype, name=self.name) + return Index(values, name=self.name) # construction helpers @final diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index 751eafaa0d78e..2f535efbd0703 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -483,16 +483,9 @@ def isin(self, values, level=None): @Appender(Index.where.__doc__) def where(self, cond, other=None): - values = self._data._ndarray + other = self._data._validate_where_value(other) - try: - other = self._data._validate_where_value(other) - except (TypeError, ValueError) as err: - # Includes tzawareness mismatch and IncompatibleFrequencyError - oth = getattr(other, "dtype", other) - raise TypeError(f"Where requires matching dtype, not {oth}") from err - - result = np.where(cond, values, other) + result = np.where(cond, self._data._ndarray, other) arr = self._data._from_backing_data(result) return type(self)._simple_new(arr, name=self.name) diff --git a/pandas/tests/indexes/datetimelike.py b/pandas/tests/indexes/datetimelike.py index be8ca61f1a730..6f078237e3a97 100644 --- a/pandas/tests/indexes/datetimelike.py +++ b/pandas/tests/indexes/datetimelike.py @@ -143,10 +143,9 @@ def test_where_cast_str(self): result = index.where(mask, [str(index[0])]) tm.assert_index_equal(result, expected) - msg = "Where requires matching dtype, not foo" + msg = "value should be a '.*', 'NaT', or array of those" with pytest.raises(TypeError, match=msg): index.where(mask, "foo") - msg = r"Where requires matching dtype, not \['foo'\]" with pytest.raises(TypeError, match=msg): index.where(mask, ["foo"]) diff --git a/pandas/tests/indexes/datetimes/test_indexing.py b/pandas/tests/indexes/datetimes/test_indexing.py index 4e46eb126894b..fe92aefc0d708 100644 --- a/pandas/tests/indexes/datetimes/test_indexing.py +++ b/pandas/tests/indexes/datetimes/test_indexing.py @@ -177,24 +177,26 @@ def test_where_invalid_dtypes(self): i2 = Index([pd.NaT, pd.NaT] + dti[2:].tolist()) - with pytest.raises(TypeError, match="Where requires matching dtype"): + msg = "value should be a 'Timestamp', 'NaT', or array of those. Got" + msg2 = "Cannot compare tz-naive and tz-aware datetime-like objects" + with pytest.raises(TypeError, match=msg2): # passing tz-naive ndarray to tzaware DTI dti.where(notna(i2), i2.values) - with pytest.raises(TypeError, match="Where requires matching dtype"): + with pytest.raises(TypeError, match=msg2): # passing tz-aware DTI to tznaive DTI dti.tz_localize(None).where(notna(i2), i2) - with pytest.raises(TypeError, match="Where requires matching dtype"): + with pytest.raises(TypeError, match=msg): dti.where(notna(i2), i2.tz_localize(None).to_period("D")) - with pytest.raises(TypeError, match="Where requires matching dtype"): + with pytest.raises(TypeError, match=msg): dti.where(notna(i2), i2.asi8.view("timedelta64[ns]")) - with pytest.raises(TypeError, match="Where requires matching dtype"): + with pytest.raises(TypeError, match=msg): dti.where(notna(i2), i2.asi8) - with pytest.raises(TypeError, match="Where requires matching dtype"): + with pytest.raises(TypeError, match=msg): # non-matching scalar dti.where(notna(i2), pd.Timedelta(days=4)) @@ -203,7 +205,7 @@ def test_where_mismatched_nat(self, tz_aware_fixture): dti = pd.date_range("2013-01-01", periods=3, tz=tz) cond = np.array([True, False, True]) - msg = "Where requires matching dtype" + msg = "value should be a 'Timestamp', 'NaT', or array of those. Got" with pytest.raises(TypeError, match=msg): # wrong-dtyped NaT dti.where(cond, np.timedelta64("NaT", "ns")) diff --git a/pandas/tests/indexes/period/test_indexing.py b/pandas/tests/indexes/period/test_indexing.py index b6d3c36f1682c..19dfa9137cc5c 100644 --- a/pandas/tests/indexes/period/test_indexing.py +++ b/pandas/tests/indexes/period/test_indexing.py @@ -545,16 +545,17 @@ def test_where_invalid_dtypes(self): i2 = PeriodIndex([NaT, NaT] + pi[2:].tolist(), freq="D") - with pytest.raises(TypeError, match="Where requires matching dtype"): + msg = "value should be a 'Period', 'NaT', or array of those" + with pytest.raises(TypeError, match=msg): pi.where(notna(i2), i2.asi8) - with pytest.raises(TypeError, match="Where requires matching dtype"): + with pytest.raises(TypeError, match=msg): pi.where(notna(i2), i2.asi8.view("timedelta64[ns]")) - with pytest.raises(TypeError, match="Where requires matching dtype"): + with pytest.raises(TypeError, match=msg): pi.where(notna(i2), i2.to_timestamp("S")) - with pytest.raises(TypeError, match="Where requires matching dtype"): + with pytest.raises(TypeError, match=msg): # non-matching scalar pi.where(notna(i2), Timedelta(days=4)) @@ -562,7 +563,7 @@ def test_where_mismatched_nat(self): pi = period_range("20130101", periods=5, freq="D") cond = np.array([True, False, True, True, False]) - msg = "Where requires matching dtype" + msg = "value should be a 'Period', 'NaT', or array of those" with pytest.raises(TypeError, match=msg): # wrong-dtyped NaT pi.where(cond, np.timedelta64("NaT", "ns")) diff --git a/pandas/tests/indexes/timedeltas/test_indexing.py b/pandas/tests/indexes/timedeltas/test_indexing.py index 396a676b97a1b..37aa9653550fb 100644 --- a/pandas/tests/indexes/timedeltas/test_indexing.py +++ b/pandas/tests/indexes/timedeltas/test_indexing.py @@ -150,16 +150,17 @@ def test_where_invalid_dtypes(self): i2 = Index([pd.NaT, pd.NaT] + tdi[2:].tolist()) - with pytest.raises(TypeError, match="Where requires matching dtype"): + msg = "value should be a 'Timedelta', 'NaT', or array of those" + with pytest.raises(TypeError, match=msg): tdi.where(notna(i2), i2.asi8) - with pytest.raises(TypeError, match="Where requires matching dtype"): + with pytest.raises(TypeError, match=msg): tdi.where(notna(i2), i2 + pd.Timestamp.now()) - with pytest.raises(TypeError, match="Where requires matching dtype"): + with pytest.raises(TypeError, match=msg): tdi.where(notna(i2), (i2 + pd.Timestamp.now()).to_period("D")) - with pytest.raises(TypeError, match="Where requires matching dtype"): + with pytest.raises(TypeError, match=msg): # non-matching scalar tdi.where(notna(i2), pd.Timestamp.now()) @@ -167,7 +168,7 @@ def test_where_mismatched_nat(self): tdi = timedelta_range("1 day", periods=3, freq="D", name="idx") cond = np.array([True, False, False]) - msg = "Where requires matching dtype" + msg = "value should be a 'Timedelta', 'NaT', or array of those" with pytest.raises(TypeError, match=msg): # wrong-dtyped NaT tdi.where(cond, np.datetime64("NaT", "ns")) diff --git a/pandas/tests/indexing/test_coercion.py b/pandas/tests/indexing/test_coercion.py index 436b2aa838b08..fd6f6fbc6a4ba 100644 --- a/pandas/tests/indexing/test_coercion.py +++ b/pandas/tests/indexing/test_coercion.py @@ -780,7 +780,7 @@ def test_where_index_timedelta64(self, value): result = tdi.where(cond, value) tm.assert_index_equal(result, expected) - msg = "Where requires matching dtype" + msg = "value should be a 'Timedelta', 'NaT', or array of thos" with pytest.raises(TypeError, match=msg): # wrong-dtyped NaT tdi.where(cond, np.datetime64("NaT", "ns")) @@ -804,11 +804,12 @@ def test_where_index_period(self): tm.assert_index_equal(result, expected) # Passing a mismatched scalar - msg = "Where requires matching dtype" + msg = "value should be a 'Period', 'NaT', or array of those" with pytest.raises(TypeError, match=msg): pi.where(cond, pd.Timedelta(days=4)) - with pytest.raises(TypeError, match=msg): + msg = r"Input has different freq=D from PeriodArray\(freq=Q-DEC\)" + with pytest.raises(ValueError, match=msg): pi.where(cond, pd.Period("2020-04-21", "D")) From c692e68f947d84ef78a19f9e841bc7afbd4964d1 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 2 Nov 2020 07:43:39 -0800 Subject: [PATCH 2/4] BUG: fix Index.where casting ints to str --- pandas/core/indexes/numeric.py | 2 ++ pandas/tests/indexes/base_class/test_where.py | 13 +++++++++++++ 2 files changed, 15 insertions(+) create mode 100644 pandas/tests/indexes/base_class/test_where.py diff --git a/pandas/core/indexes/numeric.py b/pandas/core/indexes/numeric.py index d6f571360b457..9eb8a8b719d41 100644 --- a/pandas/core/indexes/numeric.py +++ b/pandas/core/indexes/numeric.py @@ -121,6 +121,8 @@ def _validate_fill_value(self, value): # force conversion to object # so we don't lose the bools raise TypeError + if isinstance(value, str): + raise TypeError return value diff --git a/pandas/tests/indexes/base_class/test_where.py b/pandas/tests/indexes/base_class/test_where.py new file mode 100644 index 0000000000000..0c8969735e14e --- /dev/null +++ b/pandas/tests/indexes/base_class/test_where.py @@ -0,0 +1,13 @@ +import numpy as np + +from pandas import Index +import pandas._testing as tm + + +class TestWhere: + def test_where_intlike_str_doesnt_cast_ints(self): + idx = Index(range(3)) + mask = np.array([True, False, True]) + res = idx.where(mask, "2") + expected = Index([0, "2", 2]) + tm.assert_index_equal(res, expected) From 7698583b04a02d6d33e356c5726eabb2ca479024 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 2 Nov 2020 14:54:57 -0800 Subject: [PATCH 3/4] whatsnew --- doc/source/whatsnew/v1.2.0.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/source/whatsnew/v1.2.0.rst b/doc/source/whatsnew/v1.2.0.rst index 8a092cb6e36db..45a95f6aeb2f6 100644 --- a/doc/source/whatsnew/v1.2.0.rst +++ b/doc/source/whatsnew/v1.2.0.rst @@ -456,6 +456,7 @@ Indexing - Bug in :meth:`Series.loc.__getitem__` with a non-unique :class:`MultiIndex` and an empty-list indexer (:issue:`13691`) - Bug in indexing on a :class:`Series` or :class:`DataFrame` with a :class:`MultiIndex` with a level named "0" (:issue:`37194`) - Bug in :meth:`Series.__getitem__` when using an unsigned integer array as an indexer giving incorrect results or segfaulting instead of raising ``KeyError`` (:issue:`37218`) +- Bug in :meth:`Index.where` incorrectly casting numeric values to strings (:issue:`37591`) Missing ^^^^^^^ From 9a156e4d095b3c703b10905872c2837bdf6cc968 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 2 Nov 2020 15:23:16 -0800 Subject: [PATCH 4/4] Post merge fixup --- pandas/core/indexes/datetimelike.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index 51bff83a543ff..9e2ac6013cb43 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -482,7 +482,7 @@ def isin(self, values, level=None): @Appender(Index.where.__doc__) def where(self, cond, other=None): - other = self._data._validate_where_value(other) + other = self._data._validate_setitem_value(other) result = np.where(cond, self._data._ndarray, other) arr = self._data._from_backing_data(result)