From ce320aeb7dc5900d4693d84b7ec62d1887bfcee5 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Fri, 17 Jan 2020 17:38:49 -0800 Subject: [PATCH] REF: simplify Float64Index.get_loc --- pandas/core/indexes/numeric.py | 23 +++++++++++---------- pandas/tests/indexes/multi/test_indexing.py | 3 ++- pandas/tests/indexes/test_numeric.py | 3 ++- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/pandas/core/indexes/numeric.py b/pandas/core/indexes/numeric.py index 53f96ace890fb..afcf8523d33d7 100644 --- a/pandas/core/indexes/numeric.py +++ b/pandas/core/indexes/numeric.py @@ -481,17 +481,18 @@ def __contains__(self, other) -> bool: @Appender(_index_shared_docs["get_loc"]) def get_loc(self, key, method=None, tolerance=None): - try: - if np.all(np.isnan(key)) or is_bool(key): - nan_idxs = self._nan_idxs - try: - return nan_idxs.item() - except ValueError: - if not len(nan_idxs): - raise KeyError(key) - return nan_idxs - except (TypeError, NotImplementedError): - pass + if is_bool(key): + # Catch this to avoid accidentally casting to 1.0 + raise KeyError(key) + + if is_float(key) and np.isnan(key): + nan_idxs = self._nan_idxs + if not len(nan_idxs): + raise KeyError(key) + elif len(nan_idxs) == 1: + return nan_idxs[0] + return nan_idxs + return super().get_loc(key, method=method, tolerance=tolerance) @cache_readonly diff --git a/pandas/tests/indexes/multi/test_indexing.py b/pandas/tests/indexes/multi/test_indexing.py index ad6f06d065150..9070eb3deffb5 100644 --- a/pandas/tests/indexes/multi/test_indexing.py +++ b/pandas/tests/indexes/multi/test_indexing.py @@ -396,7 +396,8 @@ def test_get_loc_missing_nan(): idx.get_loc(3) with pytest.raises(KeyError, match=r"^nan$"): idx.get_loc(np.nan) - with pytest.raises(KeyError, match=r"^\[nan\]$"): + with pytest.raises(TypeError, match=r"'\[nan\]' is an invalid key"): + # listlike/non-hashable raises TypeError idx.get_loc([np.nan]) diff --git a/pandas/tests/indexes/test_numeric.py b/pandas/tests/indexes/test_numeric.py index 582f6c619d287..12cc51222e6bb 100644 --- a/pandas/tests/indexes/test_numeric.py +++ b/pandas/tests/indexes/test_numeric.py @@ -389,7 +389,8 @@ def test_get_loc_missing_nan(self): idx.get_loc(3) with pytest.raises(KeyError, match="^nan$"): idx.get_loc(np.nan) - with pytest.raises(KeyError, match=r"^\[nan\]$"): + with pytest.raises(TypeError, match=r"'\[nan\]' is an invalid key"): + # listlike/non-hashable raises TypeError idx.get_loc([np.nan]) def test_contains_nans(self):