From 90d571c54df53932ae0e2ad36ae84c3bc2ee7bd4 Mon Sep 17 00:00:00 2001 From: Brock Date: Fri, 22 Jan 2021 13:35:41 -0800 Subject: [PATCH] REF: standardize get_indexer/get_indexer_non_unique --- pandas/core/indexes/base.py | 33 +++++++++++-------- pandas/tests/indexes/numeric/test_indexing.py | 31 +++++++++++++++++ pandas/tests/indexes/test_base.py | 11 +++++++ pandas/tests/series/methods/test_reindex.py | 5 ++- 4 files changed, 65 insertions(+), 15 deletions(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index e91a25a9e23e8..bb5f05147aefb 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -3289,11 +3289,10 @@ def get_indexer( if not self._index_as_unique: raise InvalidIndexError(self._requires_unique_msg) - # Treat boolean labels passed to a numeric index as not found. Without - # this fix False and True would be treated as 0 and 1 respectively. - # (GH #16877) - if target.is_boolean() and self.is_numeric(): - return ensure_platform_int(np.repeat(-1, target.size)) + if not self._should_compare(target) and not is_interval_dtype(self.dtype): + # IntervalIndex get special treatment bc numeric scalars can be + # matched to Interval scalars + return self._get_indexer_non_comparable(target, method=method, unique=True) pself, ptarget = self._maybe_promote(target) if pself is not self or ptarget is not target: @@ -3310,8 +3309,9 @@ def _get_indexer( tolerance = self._convert_tolerance(tolerance, target) if not is_dtype_equal(self.dtype, target.dtype): - this = self.astype(object) - target = target.astype(object) + dtype = find_common_type([self.dtype, target.dtype]) + this = self.astype(dtype, copy=False) + target = target.astype(dtype, copy=False) return this.get_indexer( target, method=method, limit=limit, tolerance=tolerance ) @@ -5060,19 +5060,15 @@ def set_value(self, arr, key, value): def get_indexer_non_unique(self, target): target = ensure_index(target) - if target.is_boolean() and self.is_numeric(): - # Treat boolean labels passed to a numeric index as not found. Without - # this fix False and True would be treated as 0 and 1 respectively. - # (GH #16877) + if not self._should_compare(target) and not is_interval_dtype(self.dtype): + # IntervalIndex get special treatment bc numeric scalars can be + # matched to Interval scalars return self._get_indexer_non_comparable(target, method=None, unique=False) pself, ptarget = self._maybe_promote(target) if pself is not self or ptarget is not target: return pself.get_indexer_non_unique(ptarget) - if not self._should_compare(target): - return self._get_indexer_non_comparable(target, method=None, unique=False) - if not is_dtype_equal(self.dtype, target.dtype): # TODO: if object, could use infer_dtype to preempt costly # conversion if still non-comparable? @@ -5193,6 +5189,15 @@ def _should_compare(self, other: Index) -> bool: """ Check if `self == other` can ever have non-False entries. """ + + if (other.is_boolean() and self.is_numeric()) or ( + self.is_boolean() and other.is_numeric() + ): + # GH#16877 Treat boolean labels passed to a numeric index as not + # found. Without this fix False and True would be treated as 0 and 1 + # respectively. + return False + other = unpack_nested_dtype(other) dtype = other.dtype return self._is_comparable_dtype(dtype) or is_object_dtype(dtype) diff --git a/pandas/tests/indexes/numeric/test_indexing.py b/pandas/tests/indexes/numeric/test_indexing.py index f329a04612e33..7420cac2f9da4 100644 --- a/pandas/tests/indexes/numeric/test_indexing.py +++ b/pandas/tests/indexes/numeric/test_indexing.py @@ -82,6 +82,37 @@ def test_get_loc_missing_nan(self): class TestGetIndexer: + @pytest.mark.parametrize("method", ["pad", "backfill", "nearest"]) + def test_get_indexer_with_method_numeric_vs_bool(self, method): + left = Index([1, 2, 3]) + right = Index([True, False]) + + with pytest.raises(TypeError, match="Cannot compare"): + left.get_indexer(right, method=method) + + with pytest.raises(TypeError, match="Cannot compare"): + right.get_indexer(left, method=method) + + def test_get_indexer_numeric_vs_bool(self): + left = Index([1, 2, 3]) + right = Index([True, False]) + + res = left.get_indexer(right) + expected = -1 * np.ones(len(right), dtype=np.intp) + tm.assert_numpy_array_equal(res, expected) + + res = right.get_indexer(left) + expected = -1 * np.ones(len(left), dtype=np.intp) + tm.assert_numpy_array_equal(res, expected) + + res = left.get_indexer_non_unique(right)[0] + expected = -1 * np.ones(len(right), dtype=np.intp) + tm.assert_numpy_array_equal(res, expected) + + res = right.get_indexer_non_unique(left)[0] + expected = -1 * np.ones(len(left), dtype=np.intp) + tm.assert_numpy_array_equal(res, expected) + def test_get_indexer_float64(self): idx = Float64Index([0.0, 1.0, 2.0]) tm.assert_numpy_array_equal( diff --git a/pandas/tests/indexes/test_base.py b/pandas/tests/indexes/test_base.py index 1c8b504a6c61c..05bc577d159dc 100644 --- a/pandas/tests/indexes/test_base.py +++ b/pandas/tests/indexes/test_base.py @@ -556,6 +556,17 @@ def test_asof(self, index): d = index[0].to_pydatetime() assert isinstance(index.asof(d), Timestamp) + def test_asof_numeric_vs_bool_raises(self): + left = Index([1, 2, 3]) + right = Index([True, False]) + + msg = "'<' not supported between instances" + with pytest.raises(TypeError, match=msg): + left.asof(right) + + with pytest.raises(TypeError, match=msg): + right.asof(left) + def test_asof_datetime_partial(self): index = date_range("2010-01-01", periods=2, freq="m") expected = Timestamp("2010-02-28") diff --git a/pandas/tests/series/methods/test_reindex.py b/pandas/tests/series/methods/test_reindex.py index 22efc99805983..ecf122679f7ca 100644 --- a/pandas/tests/series/methods/test_reindex.py +++ b/pandas/tests/series/methods/test_reindex.py @@ -293,7 +293,10 @@ def test_reindex_datetimeindexes_tz_naive_and_aware(): idx = date_range("20131101", tz="America/Chicago", periods=7) newidx = date_range("20131103", periods=10, freq="H") s = Series(range(7), index=idx) - msg = "Cannot compare tz-naive and tz-aware timestamps" + msg = ( + r"Cannot compare dtypes datetime64\[ns, America/Chicago\] " + r"and datetime64\[ns\]" + ) with pytest.raises(TypeError, match=msg): s.reindex(newidx, method="ffill")