Skip to content

BUG: DTI/TDI/PI.where accepting non-matching dtypes #30791

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 34 commits into from
Jan 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
7cb6295
remove ensure_datetimelike_to_i8
jbrockmendel Jan 7, 2020
6d33f16
typo fixup
jbrockmendel Jan 7, 2020
6bc3cdb
handle categories
jbrockmendel Jan 7, 2020
3d69e34
handle categories
jbrockmendel Jan 7, 2020
c6d8558
handle categories
jbrockmendel Jan 7, 2020
c6c8ae5
handle categories
jbrockmendel Jan 7, 2020
975e4aa
handle categories
jbrockmendel Jan 7, 2020
546872a
handle categories
jbrockmendel Jan 7, 2020
4598801
handle categories
jbrockmendel Jan 7, 2020
4b15723
handle categories
jbrockmendel Jan 7, 2020
a4d338c
handle categories
jbrockmendel Jan 7, 2020
fad9b7e
handle None
jbrockmendel Jan 7, 2020
466679e
handle None
jbrockmendel Jan 7, 2020
727afc6
handle None
jbrockmendel Jan 7, 2020
31ed9b8
handle None
jbrockmendel Jan 7, 2020
eaa34b0
handle None
jbrockmendel Jan 7, 2020
ba0e555
handle None
jbrockmendel Jan 7, 2020
24f9327
handle None
jbrockmendel Jan 7, 2020
17fe5bb
handle None
jbrockmendel Jan 7, 2020
39a398c
handle None
jbrockmendel Jan 7, 2020
679aa96
handle None
jbrockmendel Jan 7, 2020
d0a7670
handle None
jbrockmendel Jan 7, 2020
437bcb0
Merge branch 'master' of https://github.com/pandas-dev/pandas into re…
jbrockmendel Jan 7, 2020
f62cb90
simplify
jbrockmendel Jan 7, 2020
bbb0568
simplify
jbrockmendel Jan 7, 2020
08eabab
simplify
jbrockmendel Jan 7, 2020
db99643
simplify
jbrockmendel Jan 7, 2020
6d7a429
fix test
jbrockmendel Jan 7, 2020
ee6db8a
extend test
jbrockmendel Jan 7, 2020
730cbd1
typo fixup
jbrockmendel Jan 7, 2020
2448aa4
tests
jbrockmendel Jan 7, 2020
00d7415
cleanup
jbrockmendel Jan 7, 2020
364765d
remove _ensure_localize
jbrockmendel Jan 7, 2020
4b96ab2
black/isort fixup
jbrockmendel Jan 7, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 12 additions & 83 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
is_unsigned_integer_dtype,
pandas_dtype,
)
from pandas.core.dtypes.generic import ABCIndexClass, ABCPeriodArray, ABCSeries
from pandas.core.dtypes.generic import ABCSeries
from pandas.core.dtypes.inference import is_array_like
from pandas.core.dtypes.missing import is_valid_nat_for_dtype, isna

Expand Down Expand Up @@ -368,16 +368,19 @@ class TimelikeOps:

def _round(self, freq, mode, ambiguous, nonexistent):
# round the local times
values = _ensure_datetimelike_to_i8(self)
if is_datetime64tz_dtype(self):
# operate on naive timestamps, then convert back to aware
naive = self.tz_localize(None)
result = naive._round(freq, mode, ambiguous, nonexistent)
aware = result.tz_localize(
self.tz, ambiguous=ambiguous, nonexistent=nonexistent
)
return aware

values = self.view("i8")
result = round_nsint64(values, mode, freq)
result = self._maybe_mask_results(result, fill_value=NaT)

dtype = self.dtype
if is_datetime64tz_dtype(self):
dtype = None
return self._ensure_localized(
self._simple_new(result, dtype=dtype), ambiguous, nonexistent
)
return self._simple_new(result, dtype=self.dtype)

@Appender((_round_doc + _round_example).format(op="round"))
def round(self, freq, ambiguous="raise", nonexistent="raise"):
Expand Down Expand Up @@ -1411,45 +1414,6 @@ def __isub__(self, other): # type: ignore
self._freq = result._freq
return self

# --------------------------------------------------------------
# Comparison Methods

def _ensure_localized(
self, arg, ambiguous="raise", nonexistent="raise", from_utc=False
):
"""
Ensure that we are re-localized.

This is for compat as we can then call this on all datetimelike
arrays generally (ignored for Period/Timedelta)

Parameters
----------
arg : Union[DatetimeLikeArray, DatetimeIndexOpsMixin, ndarray]
ambiguous : str, bool, or bool-ndarray, default 'raise'
nonexistent : str, default 'raise'
from_utc : bool, default False
If True, localize the i8 ndarray to UTC first before converting to
the appropriate tz. If False, localize directly to the tz.

Returns
-------
localized array
"""

# reconvert to local tz
tz = getattr(self, "tz", None)
if tz is not None:
if not isinstance(arg, type(self)):
arg = self._simple_new(arg)
if from_utc:
arg = arg.tz_localize("UTC").tz_convert(self.tz)
else:
arg = arg.tz_localize(
self.tz, ambiguous=ambiguous, nonexistent=nonexistent
)
return arg

# --------------------------------------------------------------
# Reductions

Expand Down Expand Up @@ -1687,38 +1651,3 @@ def maybe_infer_freq(freq):
freq_infer = True
freq = None
return freq, freq_infer


def _ensure_datetimelike_to_i8(other, to_utc=False):
"""
Helper for coercing an input scalar or array to i8.

Parameters
----------
other : 1d array
to_utc : bool, default False
If True, convert the values to UTC before extracting the i8 values
If False, extract the i8 values directly.

Returns
-------
i8 1d array
"""
from pandas import Index

if lib.is_scalar(other) and isna(other):
return iNaT
elif isinstance(other, (ABCPeriodArray, ABCIndexClass, DatetimeLikeArrayMixin)):
# convert tz if needed
if getattr(other, "tz", None) is not None:
if to_utc:
other = other.tz_convert("UTC")
else:
other = other.tz_localize(None)
else:
try:
return np.array(other, copy=False).view("i8")
except TypeError:
# period array cannot be coerced to int
other = Index(other)
return other.asi8
44 changes: 24 additions & 20 deletions pandas/core/indexes/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,18 @@
from pandas.core.dtypes.common import (
ensure_int64,
is_bool_dtype,
is_categorical_dtype,
is_dtype_equal,
is_float,
is_integer,
is_list_like,
is_period_dtype,
is_scalar,
needs_i8_conversion,
)
from pandas.core.dtypes.concat import concat_compat
from pandas.core.dtypes.generic import ABCIndex, ABCIndexClass, ABCSeries
from pandas.core.dtypes.missing import isna

from pandas.core import algorithms
from pandas.core.accessor import PandasDelegate
Expand All @@ -34,10 +37,7 @@
ExtensionOpsMixin,
TimedeltaArray,
)
from pandas.core.arrays.datetimelike import (
DatetimeLikeArrayMixin,
_ensure_datetimelike_to_i8,
)
from pandas.core.arrays.datetimelike import DatetimeLikeArrayMixin
import pandas.core.indexes.base as ibase
from pandas.core.indexes.base import Index, _index_shared_docs
from pandas.core.indexes.numeric import Int64Index
Expand Down Expand Up @@ -177,18 +177,6 @@ def equals(self, other) -> bool:

return np.array_equal(self.asi8, other.asi8)

def _ensure_localized(
self, arg, ambiguous="raise", nonexistent="raise", from_utc=False
):
# See DatetimeLikeArrayMixin._ensure_localized.__doc__
if getattr(self, "tz", None):
# ensure_localized is only relevant for tz-aware DTI
result = self._data._ensure_localized(
arg, ambiguous=ambiguous, nonexistent=nonexistent, from_utc=from_utc
)
return type(self)._simple_new(result, name=self.name)
return arg

@Appender(_index_shared_docs["contains"] % _index_doc_kwargs)
def __contains__(self, key):
try:
Expand Down Expand Up @@ -491,11 +479,27 @@ def repeat(self, repeats, axis=None):

@Appender(_index_shared_docs["where"] % _index_doc_kwargs)
def where(self, cond, other=None):
other = _ensure_datetimelike_to_i8(other, to_utc=True)
values = _ensure_datetimelike_to_i8(self, to_utc=True)
result = np.where(cond, values, other).astype("i8")
values = self.view("i8")

if is_scalar(other) and isna(other):
other = NaT.value

result = self._ensure_localized(result, from_utc=True)
else:
# Do type inference if necessary up front
# e.g. we passed PeriodIndex.values and got an ndarray of Periods
other = Index(other)

if is_categorical_dtype(other):
# e.g. we have a Categorical holding self.dtype
if needs_i8_conversion(other.categories):
other = other._internal_get_values()

if not is_dtype_equal(self.dtype, other.dtype):
raise TypeError(f"Where requires matching dtype, not {other.dtype}")

other = other.view("i8")

result = np.where(cond, values, other).astype("i8")
return self._shallow_copy(result)

def _summary(self, name=None):
Expand Down
25 changes: 24 additions & 1 deletion pandas/tests/indexes/datetimes/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,32 @@ def test_where_other(self):

i2 = i.copy()
i2 = Index([pd.NaT, pd.NaT] + i[2:].tolist())
result = i.where(notna(i2), i2.values)
result = i.where(notna(i2), i2._values)
tm.assert_index_equal(result, i2)

def test_where_invalid_dtypes(self):
dti = pd.date_range("20130101", periods=3, tz="US/Eastern")

i2 = dti.copy()
i2 = Index([pd.NaT, pd.NaT] + dti[2:].tolist())

with pytest.raises(TypeError, match="Where requires matching dtype"):
# passing tz-naive ndarray to tzaware DTI
dti.where(notna(i2), i2.values)

with pytest.raises(TypeError, match="Where requires matching dtype"):
# passing tz-aware DTI to tznaive DTI
dti.tz_localize(None).where(notna(i2), i2)

with pytest.raises(TypeError, match="Where requires matching dtype"):
dti.where(notna(i2), i2.tz_localize(None).to_period("D"))

with pytest.raises(TypeError, match="Where requires matching dtype"):
dti.where(notna(i2), i2.asi8.view("timedelta64[ns]"))

with pytest.raises(TypeError, match="Where requires matching dtype"):
dti.where(notna(i2), i2.asi8)

def test_where_tz(self):
i = pd.date_range("20130101", periods=3, tz="US/Eastern")
result = i.where(notna(i))
Expand Down
15 changes: 15 additions & 0 deletions pandas/tests/indexes/period/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,21 @@ def test_where_other(self):
result = i.where(notna(i2), i2.values)
tm.assert_index_equal(result, i2)

def test_where_invalid_dtypes(self):
pi = period_range("20130101", periods=5, freq="D")

i2 = pi.copy()
i2 = pd.PeriodIndex([pd.NaT, pd.NaT] + pi[2:].tolist(), freq="D")

with pytest.raises(TypeError, match="Where requires matching dtype"):
pi.where(notna(i2), i2.asi8)

with pytest.raises(TypeError, match="Where requires matching dtype"):
pi.where(notna(i2), i2.asi8.view("timedelta64[ns]"))

with pytest.raises(TypeError, match="Where requires matching dtype"):
pi.where(notna(i2), i2.to_timestamp("S"))


class TestTake:
def test_take(self):
Expand Down
18 changes: 15 additions & 3 deletions pandas/tests/indexes/timedeltas/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

import pandas as pd
from pandas import Index, Timedelta, TimedeltaIndex, timedelta_range
from pandas import Index, Timedelta, TimedeltaIndex, notna, timedelta_range
import pandas._testing as tm


Expand Down Expand Up @@ -58,8 +58,20 @@ def test_timestamp_invalid_key(self, key):


class TestWhere:
# placeholder for symmetry with DatetimeIndex and PeriodIndex tests
pass
def test_where_invalid_dtypes(self):
tdi = timedelta_range("1 day", periods=3, freq="D", name="idx")

i2 = tdi.copy()
i2 = Index([pd.NaT, pd.NaT] + tdi[2:].tolist())

with pytest.raises(TypeError, match="Where requires matching dtype"):
tdi.where(notna(i2), i2.asi8)

with pytest.raises(TypeError, match="Where requires matching dtype"):
tdi.where(notna(i2), i2 + pd.Timestamp.now())

with pytest.raises(TypeError, match="Where requires matching dtype"):
tdi.where(notna(i2), (i2 + pd.Timestamp.now()).to_period("D"))


class TestTake:
Expand Down