Skip to content

REF: do all casting _before_ call to DatetimeEngine.get_loc #30948

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jan 15, 2020
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 29 additions & 23 deletions pandas/_libs/index.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ cnp.import_array()

cimport pandas._libs.util as util

from pandas._libs.tslibs.conversion cimport maybe_datetimelike_to_i8
from pandas._libs.tslibs.nattype cimport c_NaT as NaT
from pandas._libs.tslibs.c_timestamp cimport _Timestamp

from pandas._libs.hashtable cimport HashTable

Expand Down Expand Up @@ -409,20 +409,27 @@ cdef class DatetimeEngine(Int64Engine):
cdef _get_box_dtype(self):
return 'M8[ns]'

cdef int64_t _unbox_scalar(self, scalar) except? -1:
# NB: caller is responsible for ensuring tzawareness compat
# before we get here
if not (isinstance(scalar, _Timestamp) or scalar is NaT):
raise TypeError(scalar)
return scalar.value

def __contains__(self, object val):
cdef:
int64_t loc
int64_t loc, conv

conv = self._unbox_scalar(val)
if self.over_size_threshold and self.is_monotonic_increasing:
if not self.is_unique:
return self._get_loc_duplicates(val)
return self._get_loc_duplicates(conv)
values = self._get_index_values()
conv = maybe_datetimelike_to_i8(val)
loc = values.searchsorted(conv, side='left')
return values[loc] == conv

self._ensure_mapping_populated()
return maybe_datetimelike_to_i8(val) in self.mapping
return conv in self.mapping

cdef _get_index_values(self):
return self.vgetter().view('i8')
Expand All @@ -431,45 +438,39 @@ cdef class DatetimeEngine(Int64Engine):
return algos.is_monotonic(values, timelike=True)

cpdef get_loc(self, object val):
# NB: the caller is responsible for ensuring that we are called
# with either a Timestamp or NaT (Timedelta or NaT for TimedeltaEngine)

cdef:
int64_t loc
if is_definitely_invalid_key(val):
raise TypeError

try:
conv = self._unbox_scalar(val)
except TypeError:
raise KeyError(val)

# Welcome to the spaghetti factory
if self.over_size_threshold and self.is_monotonic_increasing:
if not self.is_unique:
val = maybe_datetimelike_to_i8(val)
return self._get_loc_duplicates(val)
return self._get_loc_duplicates(conv)
values = self._get_index_values()

try:
conv = maybe_datetimelike_to_i8(val)
loc = values.searchsorted(conv, side='left')
except TypeError:
raise KeyError(val)
loc = values.searchsorted(conv, side='left')

if loc == len(values) or values[loc] != conv:
raise KeyError(val)
return loc

self._ensure_mapping_populated()
if not self.unique:
val = maybe_datetimelike_to_i8(val)
return self._get_loc_duplicates(val)
return self._get_loc_duplicates(conv)

try:
return self.mapping.get_item(val.value)
return self.mapping.get_item(conv)
except KeyError:
raise KeyError(val)
except AttributeError:
pass

try:
val = maybe_datetimelike_to_i8(val)
return self.mapping.get_item(val)
except (TypeError, ValueError):
raise KeyError(val)

def get_indexer(self, values):
self._ensure_mapping_populated()
Expand All @@ -496,6 +497,11 @@ cdef class TimedeltaEngine(DatetimeEngine):
cdef _get_box_dtype(self):
return 'm8[ns]'

cdef int64_t _unbox_scalar(self, scalar) except? -1:
if not (isinstance(scalar, Timedelta) or scalar is NaT):
raise TypeError(scalar)
return scalar.value


cdef class PeriodEngine(Int64Engine):

Expand Down
2 changes: 0 additions & 2 deletions pandas/_libs/tslibs/conversion.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,4 @@ cdef int64_t get_datetime64_nanos(object val) except? -1

cpdef int64_t pydt_to_i8(object pydt) except? -1

cdef maybe_datetimelike_to_i8(object val)

cpdef datetime localize_pydatetime(datetime dt, object tz)
25 changes: 0 additions & 25 deletions pandas/_libs/tslibs/conversion.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -202,31 +202,6 @@ def datetime_to_datetime64(object[:] values):
return result, inferred_tz


cdef inline maybe_datetimelike_to_i8(object val):
"""
Try to convert to a nanosecond timestamp. Fall back to returning the
input value.

Parameters
----------
val : object

Returns
-------
val : int64 timestamp or original input
"""
cdef:
npy_datetimestruct dts
try:
return val.value
except AttributeError:
if is_datetime64_object(val):
return get_datetime64_value(val)
elif PyDateTime_Check(val):
return convert_datetime_to_tsobject(val, None).value
return val


# ----------------------------------------------------------------------
# _TSObject Conversion

Expand Down
53 changes: 53 additions & 0 deletions pandas/tests/indexes/test_engines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest

import pandas as pd


class TestDatetimeEngine:
@pytest.mark.parametrize(
"scalar",
[
pd.Timedelta(pd.Timestamp("2016-01-01").asm8.view("m8[ns]")),
pd.Timestamp("2016-01-01").value,
pd.Timestamp("2016-01-01").to_pydatetime(),
pd.Timestamp("2016-01-01").to_datetime64(),
],
)
def test_not_contains_requires_timestamp(self, scalar):
dti1 = pd.date_range("2016-01-01", periods=3)
dti2 = dti1.insert(1, pd.NaT) # non-monotonic
dti3 = dti1.insert(3, dti1[0]) # non-unique
dti4 = pd.date_range("2016-01-01", freq="ns", periods=2_000_000)
dti5 = dti4.insert(0, dti4[0]) # over size threshold, not unique

for dti in [dti1, dti2, dti3, dti4, dti5]:
with pytest.raises(TypeError):
scalar in dti._engine

with pytest.raises(KeyError):
dti._engine.get_loc(scalar)


class TestTimedeltaEngine:
@pytest.mark.parametrize(
"scalar",
[
pd.Timestamp(pd.Timedelta(days=42).asm8.view("datetime64[ns]")),
pd.Timedelta(days=42).value,
pd.Timedelta(days=42).to_pytimedelta(),
pd.Timedelta(days=42).to_timedelta64(),
],
)
def test_not_contains_requires_timestamp(self, scalar):
tdi1 = pd.timedelta_range("42 days", freq="9h", periods=1234)
tdi2 = tdi1.insert(1, pd.NaT) # non-monotonic
tdi3 = tdi1.insert(3, tdi1[0]) # non-unique
tdi4 = pd.timedelta_range("42 days", freq="ns", periods=2_000_000)
tdi5 = tdi4.insert(0, tdi4[0]) # over size threshold, not unique

for tdi in [tdi1, tdi2, tdi3, tdi4, tdi5]:
with pytest.raises(TypeError):
scalar in tdi._engine

with pytest.raises(KeyError):
tdi._engine.get_loc(scalar)