Skip to content

BUG: Index.where casting ints to str #37591

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Nov 3, 2020
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ Indexing
- Bug in :meth:`Series.loc.__getitem__` with a non-unique :class:`MultiIndex` and an empty-list indexer (:issue:`13691`)
- Bug in indexing on a :class:`Series` or :class:`DataFrame` with a :class:`MultiIndex` with a level named "0" (:issue:`37194`)
- Bug in :meth:`Series.__getitem__` when using an unsigned integer array as an indexer giving incorrect results or segfaulting instead of raising ``KeyError`` (:issue:`37218`)
- Bug in :meth:`Index.where` incorrectly casting numeric values to strings (:issue:`37591`)

Missing
^^^^^^^
Expand Down
18 changes: 5 additions & 13 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
ensure_int64,
ensure_object,
ensure_platform_int,
is_bool,
is_bool_dtype,
is_categorical_dtype,
is_datetime64_any_dtype,
Expand Down Expand Up @@ -4079,23 +4078,16 @@ def where(self, cond, other=None):
if other is None:
other = self._na_value

dtype = self.dtype
values = self.values

if is_bool(other) or is_bool_dtype(other):

# bools force casting
values = values.astype(object)
dtype = None
try:
self._validate_fill_value(other)
except (ValueError, TypeError):
return self.astype(object).where(cond, other)

values = np.where(cond, values, other)

if self._is_numeric_dtype and np.any(isna(values)):
# We can't coerce to the numeric dtype of "self" (unless
# it's float) if there are NaN values in our output.
dtype = None

return Index(values, dtype=dtype, name=self.name)
return Index(values, name=self.name)

# construction helpers
@final
Expand Down
11 changes: 2 additions & 9 deletions pandas/core/indexes/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,16 +482,9 @@ def isin(self, values, level=None):

@Appender(Index.where.__doc__)
def where(self, cond, other=None):
values = self._data._ndarray
other = self._data._validate_setitem_value(other)

try:
other = self._data._validate_setitem_value(other)
except (TypeError, ValueError) as err:
# Includes tzawareness mismatch and IncompatibleFrequencyError
oth = getattr(other, "dtype", other)
raise TypeError(f"Where requires matching dtype, not {oth}") from err

result = np.where(cond, values, other)
result = np.where(cond, self._data._ndarray, other)
arr = self._data._from_backing_data(result)
return type(self)._simple_new(arr, name=self.name)

Expand Down
2 changes: 2 additions & 0 deletions pandas/core/indexes/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ def _validate_fill_value(self, value):
# force conversion to object
# so we don't lose the bools
raise TypeError
if isinstance(value, str):
raise TypeError

return value

Expand Down
13 changes: 13 additions & 0 deletions pandas/tests/indexes/base_class/test_where.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import numpy as np

from pandas import Index
import pandas._testing as tm


class TestWhere:
def test_where_intlike_str_doesnt_cast_ints(self):
idx = Index(range(3))
mask = np.array([True, False, True])
res = idx.where(mask, "2")
expected = Index([0, "2", 2])
tm.assert_index_equal(res, expected)
3 changes: 1 addition & 2 deletions pandas/tests/indexes/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,9 @@ def test_where_cast_str(self):
result = index.where(mask, [str(index[0])])
tm.assert_index_equal(result, expected)

msg = "Where requires matching dtype, not foo"
msg = "value should be a '.*', 'NaT', or array of those"
with pytest.raises(TypeError, match=msg):
index.where(mask, "foo")

msg = r"Where requires matching dtype, not \['foo'\]"
with pytest.raises(TypeError, match=msg):
index.where(mask, ["foo"])
16 changes: 9 additions & 7 deletions pandas/tests/indexes/datetimes/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,24 +177,26 @@ def test_where_invalid_dtypes(self):

i2 = Index([pd.NaT, pd.NaT] + dti[2:].tolist())

with pytest.raises(TypeError, match="Where requires matching dtype"):
msg = "value should be a 'Timestamp', 'NaT', or array of those. Got"
msg2 = "Cannot compare tz-naive and tz-aware datetime-like objects"
with pytest.raises(TypeError, match=msg2):
# passing tz-naive ndarray to tzaware DTI
dti.where(notna(i2), i2.values)

with pytest.raises(TypeError, match="Where requires matching dtype"):
with pytest.raises(TypeError, match=msg2):
# passing tz-aware DTI to tznaive DTI
dti.tz_localize(None).where(notna(i2), i2)

with pytest.raises(TypeError, match="Where requires matching dtype"):
with pytest.raises(TypeError, match=msg):
dti.where(notna(i2), i2.tz_localize(None).to_period("D"))

with pytest.raises(TypeError, match="Where requires matching dtype"):
with pytest.raises(TypeError, match=msg):
dti.where(notna(i2), i2.asi8.view("timedelta64[ns]"))

with pytest.raises(TypeError, match="Where requires matching dtype"):
with pytest.raises(TypeError, match=msg):
dti.where(notna(i2), i2.asi8)

with pytest.raises(TypeError, match="Where requires matching dtype"):
with pytest.raises(TypeError, match=msg):
# non-matching scalar
dti.where(notna(i2), pd.Timedelta(days=4))

Expand All @@ -203,7 +205,7 @@ def test_where_mismatched_nat(self, tz_aware_fixture):
dti = pd.date_range("2013-01-01", periods=3, tz=tz)
cond = np.array([True, False, True])

msg = "Where requires matching dtype"
msg = "value should be a 'Timestamp', 'NaT', or array of those. Got"
with pytest.raises(TypeError, match=msg):
# wrong-dtyped NaT
dti.where(cond, np.timedelta64("NaT", "ns"))
Expand Down
11 changes: 6 additions & 5 deletions pandas/tests/indexes/period/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,24 +545,25 @@ def test_where_invalid_dtypes(self):

i2 = PeriodIndex([NaT, NaT] + pi[2:].tolist(), freq="D")

with pytest.raises(TypeError, match="Where requires matching dtype"):
msg = "value should be a 'Period', 'NaT', or array of those"
with pytest.raises(TypeError, match=msg):
pi.where(notna(i2), i2.asi8)

with pytest.raises(TypeError, match="Where requires matching dtype"):
with pytest.raises(TypeError, match=msg):
pi.where(notna(i2), i2.asi8.view("timedelta64[ns]"))

with pytest.raises(TypeError, match="Where requires matching dtype"):
with pytest.raises(TypeError, match=msg):
pi.where(notna(i2), i2.to_timestamp("S"))

with pytest.raises(TypeError, match="Where requires matching dtype"):
with pytest.raises(TypeError, match=msg):
# non-matching scalar
pi.where(notna(i2), Timedelta(days=4))

def test_where_mismatched_nat(self):
pi = period_range("20130101", periods=5, freq="D")
cond = np.array([True, False, True, True, False])

msg = "Where requires matching dtype"
msg = "value should be a 'Period', 'NaT', or array of those"
with pytest.raises(TypeError, match=msg):
# wrong-dtyped NaT
pi.where(cond, np.timedelta64("NaT", "ns"))
Expand Down
11 changes: 6 additions & 5 deletions pandas/tests/indexes/timedeltas/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,24 +150,25 @@ def test_where_invalid_dtypes(self):

i2 = Index([pd.NaT, pd.NaT] + tdi[2:].tolist())

with pytest.raises(TypeError, match="Where requires matching dtype"):
msg = "value should be a 'Timedelta', 'NaT', or array of those"
with pytest.raises(TypeError, match=msg):
tdi.where(notna(i2), i2.asi8)

with pytest.raises(TypeError, match="Where requires matching dtype"):
with pytest.raises(TypeError, match=msg):
tdi.where(notna(i2), i2 + pd.Timestamp.now())

with pytest.raises(TypeError, match="Where requires matching dtype"):
with pytest.raises(TypeError, match=msg):
tdi.where(notna(i2), (i2 + pd.Timestamp.now()).to_period("D"))

with pytest.raises(TypeError, match="Where requires matching dtype"):
with pytest.raises(TypeError, match=msg):
# non-matching scalar
tdi.where(notna(i2), pd.Timestamp.now())

def test_where_mismatched_nat(self):
tdi = timedelta_range("1 day", periods=3, freq="D", name="idx")
cond = np.array([True, False, False])

msg = "Where requires matching dtype"
msg = "value should be a 'Timedelta', 'NaT', or array of those"
with pytest.raises(TypeError, match=msg):
# wrong-dtyped NaT
tdi.where(cond, np.datetime64("NaT", "ns"))
Expand Down
7 changes: 4 additions & 3 deletions pandas/tests/indexing/test_coercion.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ def test_where_index_timedelta64(self, value):
result = tdi.where(cond, value)
tm.assert_index_equal(result, expected)

msg = "Where requires matching dtype"
msg = "value should be a 'Timedelta', 'NaT', or array of thos"
with pytest.raises(TypeError, match=msg):
# wrong-dtyped NaT
tdi.where(cond, np.datetime64("NaT", "ns"))
Expand All @@ -804,11 +804,12 @@ def test_where_index_period(self):
tm.assert_index_equal(result, expected)

# Passing a mismatched scalar
msg = "Where requires matching dtype"
msg = "value should be a 'Period', 'NaT', or array of those"
with pytest.raises(TypeError, match=msg):
pi.where(cond, pd.Timedelta(days=4))

with pytest.raises(TypeError, match=msg):
msg = r"Input has different freq=D from PeriodArray\(freq=Q-DEC\)"
with pytest.raises(ValueError, match=msg):
pi.where(cond, pd.Period("2020-04-21", "D"))


Expand Down