diff --git a/pandas/core/arrays/timedeltas.py b/pandas/core/arrays/timedeltas.py index 516a271042c9b..d77a37ad355a7 100644 --- a/pandas/core/arrays/timedeltas.py +++ b/pandas/core/arrays/timedeltas.py @@ -44,10 +44,6 @@ from pandas.tseries.offsets import Tick -def _is_convertible_to_td(key): - return isinstance(key, (Tick, timedelta, np.timedelta64, str)) - - def _field_accessor(name, alias, docstring=None): def f(self): values = self.asi8 diff --git a/pandas/core/indexes/timedeltas.py b/pandas/core/indexes/timedeltas.py index 1dd5c065ec216..d0a31b68250ad 100644 --- a/pandas/core/indexes/timedeltas.py +++ b/pandas/core/indexes/timedeltas.py @@ -1,5 +1,4 @@ """ implement the TimedeltaIndex """ -from datetime import datetime import numpy as np @@ -10,19 +9,23 @@ _TD_DTYPE, is_float, is_integer, - is_list_like, is_scalar, is_timedelta64_dtype, is_timedelta64_ns_dtype, pandas_dtype, ) -from pandas.core.dtypes.missing import isna +from pandas.core.dtypes.missing import is_valid_nat_for_dtype from pandas.core.accessor import delegate_names from pandas.core.arrays import datetimelike as dtl -from pandas.core.arrays.timedeltas import TimedeltaArray, _is_convertible_to_td +from pandas.core.arrays.timedeltas import TimedeltaArray import pandas.core.common as com -from pandas.core.indexes.base import Index, _index_shared_docs, maybe_extract_name +from pandas.core.indexes.base import ( + Index, + InvalidIndexError, + _index_shared_docs, + maybe_extract_name, +) from pandas.core.indexes.datetimelike import ( DatetimeIndexOpsMixin, DatetimelikeDelegateMixin, @@ -236,22 +239,10 @@ def get_value(self, series, key): Fast lookup of value from 1-dimensional ndarray. Only use this if you know what you're doing """ - - if isinstance(key, str): - try: - key = Timedelta(key) - except ValueError: - raise KeyError(key) - - if isinstance(key, self._data._recognized_scalars) or key is NaT: - key = Timedelta(key) - return self.get_value_maybe_box(series, key) - - value = Index.get_value(self, series, key) - return com.maybe_box(self, value, series, key) - - def get_value_maybe_box(self, series, key: Timedelta): - loc = self.get_loc(key) + if is_integer(key): + loc = key + else: + loc = self.get_loc(key) return self._get_values_for_loc(series, loc) def get_loc(self, key, method=None, tolerance=None): @@ -260,27 +251,31 @@ def get_loc(self, key, method=None, tolerance=None): Returns ------- - loc : int + loc : int, slice, or ndarray[int] """ - if is_list_like(key) or (isinstance(key, datetime) and key is not NaT): - # GH#20464 datetime check here is to ensure we don't allow - # datetime objects to be incorrectly treated as timedelta - # objects; NaT is a special case because it plays a double role - # as Not-A-Timedelta - raise TypeError - - if isna(key): + if not is_scalar(key): + raise InvalidIndexError(key) + + if is_valid_nat_for_dtype(key, self.dtype): key = NaT + elif isinstance(key, str): + try: + key = Timedelta(key) + except ValueError: + raise KeyError(key) + + elif isinstance(key, self._data._recognized_scalars) or key is NaT: + key = Timedelta(key) + + else: + raise KeyError(key) + if tolerance is not None: # try converting tolerance now, so errors don't get swallowed by # the try/except clauses below tolerance = self._convert_tolerance(tolerance, np.asarray(key)) - if _is_convertible_to_td(key) or key is NaT: - key = Timedelta(key) - return Index.get_loc(self, key, method, tolerance) - return Index.get_loc(self, key, method, tolerance) def _maybe_cast_slice_bound(self, label, side, kind): diff --git a/pandas/core/indexing.py b/pandas/core/indexing.py index 63a86792082da..0b67ae902b075 100755 --- a/pandas/core/indexing.py +++ b/pandas/core/indexing.py @@ -1608,19 +1608,22 @@ def _convert_to_indexer(self, obj, axis: int, raise_missing: bool = False): is_int_index = labels.is_integer() is_int_positional = is_integer(obj) and not is_int_index - # if we are a label return me - try: - return labels.get_loc(obj) - except LookupError: - if isinstance(obj, tuple) and isinstance(labels, ABCMultiIndex): - if len(obj) == labels.nlevels: - return {"key": obj} - raise - except TypeError: - pass - except ValueError: - if not is_int_positional: - raise + if is_scalar(obj) or isinstance(labels, ABCMultiIndex): + # Otherwise get_loc will raise InvalidIndexError + + # if we are a label return me + try: + return labels.get_loc(obj) + except LookupError: + if isinstance(obj, tuple) and isinstance(labels, ABCMultiIndex): + if len(obj) == labels.nlevels: + return {"key": obj} + raise + except TypeError: + pass + except ValueError: + if not is_int_positional: + raise # a positional if is_int_positional: diff --git a/pandas/tests/indexes/timedeltas/test_indexing.py b/pandas/tests/indexes/timedeltas/test_indexing.py index e8665ee1a3555..14fff6f9c85b5 100644 --- a/pandas/tests/indexes/timedeltas/test_indexing.py +++ b/pandas/tests/indexes/timedeltas/test_indexing.py @@ -1,4 +1,5 @@ from datetime import datetime, timedelta +import re import numpy as np import pytest @@ -48,12 +49,19 @@ def test_getitem(self): @pytest.mark.parametrize( "key", - [pd.Timestamp("1970-01-01"), pd.Timestamp("1970-01-02"), datetime(1970, 1, 1)], + [ + pd.Timestamp("1970-01-01"), + pd.Timestamp("1970-01-02"), + datetime(1970, 1, 1), + pd.Timestamp("1970-01-03").to_datetime64(), + # non-matching NA values + np.datetime64("NaT"), + ], ) def test_timestamp_invalid_key(self, key): # GH#20464 tdi = pd.timedelta_range(0, periods=10) - with pytest.raises(TypeError): + with pytest.raises(KeyError, match=re.escape(repr(key))): tdi.get_loc(key)